Initial commit on Unstract

.github/pull_request_template.md (vendored, new file, 35 lines)
@@ -0,0 +1,35 @@
## What

...

## Why

...

## How

...

## Relevant Docs

-

## Related Issues or PRs

-

## Dependencies Versions / Env Variables

-

## Notes on Testing

...

## Screenshots

...

## Checklist

I have read and understood the [Contribution Guidelines]().

.github/workflows/ci-container-build.yaml (vendored, new file, 86 lines)
@@ -0,0 +1,86 @@
name: Container Image Build Test for PRs

env:
  VERSION: ci-test  # Used for docker tag

on:
  push:
    branches:
      - main
      - development
    paths:
      - 'backend/**'
      - 'frontend/**'
      - 'unstract/**'
      - 'document-service/**'
      - 'platform-service/**'
      - 'x2text-service/**'
      - 'worker/**'
      - 'docker/dockerfiles/**'
  pull_request:
    types: [opened, synchronize, reopened, ready_for_review]
    branches:
      - main
      - development
    paths:
      - 'backend/**'
      - 'frontend/**'
      - 'unstract/**'
      - 'document-service/**'
      - 'platform-service/**'
      - 'x2text-service/**'
      - 'worker/**'
      - 'docker/dockerfiles/**'

jobs:
  build:
    if: github.event.pull_request.draft == false
    runs-on: custom-k8s-runner
    steps:
      - name: Checkout code
        uses: actions/checkout@v3
      - name: Login to Docker Hub
        uses: docker/login-action@v2
        with:
          username: ${{ secrets.DOCKERHUB_USERNAME }}
          password: ${{ secrets.DOCKERHUB_TOKEN }}
      - name: Container Build
        working-directory: ./docker
        run: |
          docker compose -f docker-compose.build.yaml build
      - name: Container Run
        working-directory: ./docker
        run: |
          cp ../backend/sample.env ../backend/.env
          cp ../document-service/sample.env ../document-service/.env
          cp ../platform-service/sample.env ../platform-service/.env
          cp ../prompt-service/sample.env ../prompt-service/.env
          cp ../worker/sample.env ../worker/.env
          cp ../x2text-service/sample.env ../x2text-service/.env
          cp sample.essentials.env essentials.env

          docker compose -f docker-compose.yaml up -d
          sleep 10
          docker compose -f docker-compose.yaml ps -a
          # Get the names of exited containers
          custom_format="{{.Name}}\t{{.Image}}\t{{.Service}}"
          EXITED_CONTAINERS=$(docker compose -f docker-compose.yaml ps -a --filter status=exited --format "$custom_format")

          # NOTE: `wc -l` reports 1 even for empty input, so test for
          # non-empty output instead to catch a single exited container too.
          if [ -n "$EXITED_CONTAINERS" ]; then
            echo "Exited Containers: $EXITED_CONTAINERS"

            SERVICES=$(echo "$EXITED_CONTAINERS" | awk '{print $3}')
            echo "Exited Services:"
            echo "$SERVICES"
            echo "There are exited containers."
            # Print logs of exited containers
            IFS=$'\n'
            for SERVICE in $SERVICES; do
              docker compose -f docker-compose.yaml logs "$SERVICE"
            done
            docker compose -f docker-compose.yaml down -v
            exit 1
          fi
          docker compose -f docker-compose.yaml down -v

.github/workflows/docker-build-push-dev.yaml (vendored, new file, 64 lines)
@@ -0,0 +1,64 @@
name: Unstract Docker Image Build and Push (Development)

on:
  workflow_dispatch:
    inputs:
      tag:
        description: "Docker image tag"
        required: true
        default: "latest"
      service_name:
        description: "Service to build"
        required: true
        default: "backend"  # Provide a default value
        type: choice
        options:  # Define available options
          - all-services
          - frontend
          - backend
          - document-service
          - platform-service
          - worker
          - prompt-service
          - x2text-service

run-name: "[${{ inputs.service_name }}] Docker Image Build and Push (Development)"

jobs:
  build-and-push:
    runs-on: custom-k8s-runner
    steps:
      - name: Output Inputs
        run: echo "${{ toJSON(github.event.inputs) }}"

      - name: Checkout code
        uses: actions/checkout@v2

      - name: Login to Docker Hub
        uses: docker/login-action@v1
        with:
          username: ${{ secrets.DOCKERHUB_USERNAME }}
          password: ${{ secrets.DOCKERHUB_TOKEN }}

      # Build and push Docker image for the specified service
      - name: Build and push image for ${{ github.event.inputs.service_name }}
        working-directory: ./docker
        if: github.event.inputs.service_name != 'all-services'
        run: |
          VERSION=${{ github.event.inputs.tag }} docker compose -f docker-compose.build.yaml build --no-cache ${{ github.event.inputs.service_name }}
          docker push unstract/${{ github.event.inputs.service_name }}:${{ github.event.inputs.tag }}

      # Build and push all service images
      - name: Build and push all images
        working-directory: ./docker
        if: github.event.inputs.service_name == 'all-services'
        run: |
          VERSION=${{ github.event.inputs.tag }} docker-compose -f docker-compose.build.yaml build --no-cache
          # Push all built images
          docker push unstract/backend:${{ github.event.inputs.tag }}
          docker push unstract/frontend:${{ github.event.inputs.tag }}
          docker push unstract/document-service:${{ github.event.inputs.tag }}
          docker push unstract/platform-service:${{ github.event.inputs.tag }}
          docker push unstract/worker:${{ github.event.inputs.tag }}
          docker push unstract/prompt-service:${{ github.event.inputs.tag }}
          docker push unstract/x2text-service:${{ github.event.inputs.tag }}

.github/workflows/docker-tools-build-push.yaml (vendored, new file, 65 lines)
@@ -0,0 +1,65 @@
name: Unstract Tools Docker Image Build and Push (Development)

on:
  workflow_dispatch:
    inputs:
      tag:
        description: "Docker image tag"
        required: true
        default: "latest"
      service_name:
        description: "Tool to build"
        required: true
        default: "tool-classifier"  # Provide a default value
        type: choice
        options:  # Define available options
          - tool-classifier
          - tool-doc-pii-redactor
          - tool-indexer
          - tool-ocr
          - tool-translate
          - tool-structure
          - tool-text-extractor

run-name: "[${{ inputs.service_name }}] Docker Image Build and Push (Development)"

jobs:
  build-and-push:
    runs-on: custom-k8s-runner
    steps:
      - name: Output Inputs
        run: echo "${{ toJSON(github.event.inputs) }}"

      - name: Checkout code
        uses: actions/checkout@v2

      - name: Login to Docker Hub
        uses: docker/login-action@v1
        with:
          username: ${{ secrets.DOCKERHUB_USERNAME }}
          password: ${{ secrets.DOCKERHUB_TOKEN }}

      - name: Build tool-classifier
        if: github.event.inputs.service_name == 'tool-classifier'
        run: docker build -t unstract/${{ github.event.inputs.service_name }}:${{ github.event.inputs.tag }} ./tools/classifier
      - name: Build tool-doc-pii-redactor
        if: github.event.inputs.service_name == 'tool-doc-pii-redactor'
        run: docker build -t unstract/${{ github.event.inputs.service_name }}:${{ github.event.inputs.tag }} ./tools/doc_pii_redactor
      - name: Build tool-indexer
        if: github.event.inputs.service_name == 'tool-indexer'
        run: docker build -t unstract/${{ github.event.inputs.service_name }}:${{ github.event.inputs.tag }} ./tools/indexer
      - name: Build tool-ocr
        if: github.event.inputs.service_name == 'tool-ocr'
        run: docker build -t unstract/${{ github.event.inputs.service_name }}:${{ github.event.inputs.tag }} ./tools/ocr
      - name: Build tool-translate
        if: github.event.inputs.service_name == 'tool-translate'
        run: docker build -t unstract/${{ github.event.inputs.service_name }}:${{ github.event.inputs.tag }} ./tools/translate
      - name: Build tool-structure
        if: github.event.inputs.service_name == 'tool-structure'
        run: docker build -t unstract/${{ github.event.inputs.service_name }}:${{ github.event.inputs.tag }} ./tools/structure
      - name: Build tool-text-extractor
        if: github.event.inputs.service_name == 'tool-text-extractor'
        run: docker build -t unstract/${{ github.event.inputs.service_name }}:${{ github.event.inputs.tag }} ./tools/text_extractor

      - name: Push Docker image to Docker Hub
        run: docker push unstract/${{ github.event.inputs.service_name }}:${{ github.event.inputs.tag }}

.github/workflows/production-build.yaml (vendored, new file, 33 lines)
@@ -0,0 +1,33 @@
name: Unstract Docker Image Build and Push (Production)

on:
  release:
    types:
      - created

run-name: "[${{ github.event.release.tag_name }}] Docker Image Build and Push (Production)"

jobs:
  build-and-push:
    runs-on: custom-k8s-runner
    strategy:
      matrix:
        service_name: [backend, frontend, document-service, platform-service, prompt-service, worker, x2text-service]

    steps:
      - name: Checkout code
        uses: actions/checkout@v2
        with:
          ref: ${{ github.event.release.tag_name }}

      - name: Login to Docker Hub
        uses: docker/login-action@v1
        with:
          username: ${{ secrets.DOCKERHUB_USERNAME }}
          password: ${{ secrets.DOCKERHUB_TOKEN }}

      - name: Build and push image
        working-directory: ./docker
        run: |
          VERSION=${{ github.event.release.tag_name }} docker-compose -f docker-compose.build.yaml build --no-cache ${{ matrix.service_name }}
          docker push unstract/${{ matrix.service_name }}:${{ github.event.release.tag_name }}

.gitignore (vendored, new file, 631 lines)
@@ -0,0 +1,631 @@
# Created by https://www.toptal.com/developers/gitignore/api/windows,macos,linux,pycharm,pycharm+all,pycharm+iml,python,visualstudiocode,react,django
# Edit at https://www.toptal.com/developers/gitignore?templates=windows,macos,linux,pycharm,pycharm+all,pycharm+iml,python,visualstudiocode,react,django

### Django ###
*.log
*.pot
*.pyc
__pycache__/
local_settings.py
media

# If your build process includes running collectstatic, then you probably don't need or want to include staticfiles/
# in your Git repository. Update and uncomment the following line accordingly.
# <django-project-name>/staticfiles/

### Django.Python Stack ###
# Byte-compiled / optimized / DLL files
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo

# Django stuff:

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml
.pdm-build
.pdm-python

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.env.export
.venv*
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

### Linux ###
*~

# temporary files which can be created if a process still has a handle open of a deleted file
.fuse_hidden*

# KDE directory preferences
.directory

# Linux trash folder which might appear on any partition or disk
.Trash-*

# .nfs files are created when an open file is removed but is still being accessed
.nfs*

### macOS ###
# General
.DS_Store
.AppleDouble
.LSOverride

# Icon must end with two \r
Icon


# Thumbnails
._*

# Files that might appear in the root of a volume
.DocumentRevisions-V100
.fseventsd
.Spotlight-V100
.TemporaryItems
.Trashes
.VolumeIcon.icns
.com.apple.timemachine.donotpresent

# Directories potentially created on remote AFP share
.AppleDB
.AppleDesktop
Network Trash Folder
Temporary Items
.apdisk

### macOS Patch ###
# iCloud generated files
*.icloud

### PyCharm ###
# Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider
# Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839

# User-specific stuff
.idea/**/workspace.xml
.idea/**/tasks.xml
.idea/**/usage.statistics.xml
.idea/**/dictionaries
.idea/**/shelf

# AWS User-specific
.idea/**/aws.xml

# Generated files
.idea/**/contentModel.xml

# Sensitive or high-churn files
.idea/**/dataSources/
.idea/**/dataSources.ids
.idea/**/dataSources.local.xml
.idea/**/sqlDataSources.xml
.idea/**/dynamic.xml
.idea/**/uiDesigner.xml
.idea/**/dbnavigator.xml

# Gradle
.idea/**/gradle.xml
.idea/**/libraries

# Gradle and Maven with auto-import
# When using Gradle or Maven with auto-import, you should exclude module files,
# since they will be recreated, and may cause churn. Uncomment if using
# auto-import.
# .idea/artifacts
# .idea/compiler.xml
# .idea/jarRepositories.xml
# .idea/modules.xml
# .idea/*.iml
# .idea/modules
# *.iml
# *.ipr

# CMake
cmake-build-*/

# Mongo Explorer plugin
.idea/**/mongoSettings.xml

# File-based project format
*.iws

# IntelliJ
out/

# mpeltonen/sbt-idea plugin
.idea_modules/

# JIRA plugin
atlassian-ide-plugin.xml

# Cursive Clojure plugin
.idea/replstate.xml

# SonarLint plugin
.idea/sonarlint/

# Crashlytics plugin (for Android Studio and IntelliJ)
com_crashlytics_export_strings.xml
crashlytics.properties
crashlytics-build.properties
fabric.properties

# Editor-based Rest Client
.idea/httpRequests

# Android studio 3.1+ serialized cache file
.idea/caches/build_file_checksums.ser

### PyCharm Patch ###
# Comment Reason: https://github.com/joeblau/gitignore.io/issues/186#issuecomment-215987721

# *.iml
# modules.xml
# .idea/misc.xml
# *.ipr

# Sonarlint plugin
# https://plugins.jetbrains.com/plugin/7973-sonarlint
.idea/**/sonarlint/

# SonarQube Plugin
# https://plugins.jetbrains.com/plugin/7238-sonarqube-community-plugin
.idea/**/sonarIssues.xml

# Markdown Navigator plugin
# https://plugins.jetbrains.com/plugin/7896-markdown-navigator-enhanced
.idea/**/markdown-navigator.xml
.idea/**/markdown-navigator-enh.xml
.idea/**/markdown-navigator/

# Cache file creation bug
# See https://youtrack.jetbrains.com/issue/JBR-2257
.idea/$CACHE_FILE$

# CodeStream plugin
# https://plugins.jetbrains.com/plugin/12206-codestream
.idea/codestream.xml

# Azure Toolkit for IntelliJ plugin
# https://plugins.jetbrains.com/plugin/8053-azure-toolkit-for-intellij
.idea/**/azureSettings.xml

### PyCharm+all ###
# Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider
# Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839

# User-specific stuff

# AWS User-specific

# Generated files

# Sensitive or high-churn files

# Gradle

# Gradle and Maven with auto-import
# When using Gradle or Maven with auto-import, you should exclude module files,
# since they will be recreated, and may cause churn. Uncomment if using
# auto-import.
# .idea/artifacts
# .idea/compiler.xml
# .idea/jarRepositories.xml
# .idea/modules.xml
# .idea/*.iml
# .idea/modules
# *.iml
# *.ipr

# CMake

# Mongo Explorer plugin

# File-based project format

# IntelliJ

# mpeltonen/sbt-idea plugin

# JIRA plugin

# Cursive Clojure plugin

# SonarLint plugin

# Crashlytics plugin (for Android Studio and IntelliJ)

# Editor-based Rest Client

# Android studio 3.1+ serialized cache file

### PyCharm+all Patch ###
# Ignore everything but code style settings and run configurations
# that are supposed to be shared within teams.

.idea/*

!.idea/codeStyles
!.idea/runConfigurations

### PyCharm+iml ###
# Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider
# Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839

# User-specific stuff

# AWS User-specific

# Generated files

# Sensitive or high-churn files

# Gradle

# Gradle and Maven with auto-import
# When using Gradle or Maven with auto-import, you should exclude module files,
# since they will be recreated, and may cause churn. Uncomment if using
# auto-import.
# .idea/artifacts
# .idea/compiler.xml
# .idea/jarRepositories.xml
# .idea/modules.xml
# .idea/*.iml
# .idea/modules
# *.iml
# *.ipr

# CMake

# Mongo Explorer plugin

# File-based project format

# IntelliJ

# mpeltonen/sbt-idea plugin

# JIRA plugin

# Cursive Clojure plugin

# SonarLint plugin

# Crashlytics plugin (for Android Studio and IntelliJ)

# Editor-based Rest Client

# Android studio 3.1+ serialized cache file

### PyCharm+iml Patch ###
# Reason: https://github.com/joeblau/gitignore.io/issues/186#issuecomment-249601023

*.iml
modules.xml
.idea/misc.xml
*.ipr

### Python ###
# Byte-compiled / optimized / DLL files

# C extensions

# Distribution / packaging

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.

# Installer logs

# Unit test / coverage reports

# Translations

# Django stuff:

# Flask stuff:

# Scrapy stuff:

# Sphinx documentation

# PyBuilder

# Jupyter Notebook

# IPython

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm

# Celery stuff

# SageMath parsed files

# Environments

# Spyder project settings

# Rope project settings

# mkdocs documentation

# mypy

# Pyre type checker

# pytype static type analyzer

# Cython debug symbols

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.

### Python Patch ###
# Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration
poetry.toml

# ruff
.ruff_cache/

# LSP config files
pyrightconfig.json

### react ###
.DS_*
logs
**/*.backup.*
**/*.back.*

node_modules
bower_components

*.sublime*

psd
thumb
sketch

### VisualStudioCode ###
.vscode/
**/.vscode/*
!.vscode/settings.json
!.vscode/tasks.json
# !.vscode/launch.json
!.vscode/extensions.json
!.vscode/*.code-snippets

# Local History for Visual Studio Code
.history/

# Built Visual Studio Code Extensions
*.vsix

### VisualStudioCode Patch ###
# Ignore all local history of files
.history
.ionide

### Windows ###
# Windows thumbnail cache files
Thumbs.db
Thumbs.db:encryptable
ehthumbs.db
ehthumbs_vista.db

# Dump file
*.stackdump

# Folder config file
[Dd]esktop.ini

# Recycle Bin used on file shares
$RECYCLE.BIN/

# Windows Installer files
*.cab
*.msi
*.msix
*.msm
*.msp

# Windows shortcuts
*.lnk

### Unstract ###

# Plugins
backend/plugins/authentication/*
!backend/plugins/authentication/auth_sample

# Tool registry
unstract/tool-registry/src/unstract/tool_registry/*.json
unstract/tool-registry/tests/*.yaml
!unstract/tool-registry/src/unstract/tool_registry/public_tools.json
unstract/tool-registry/src/unstract/tool_registry/config/registry.yaml

# Docker related
# End of https://www.toptal.com/developers/gitignore/api/windows,macos,linux,pycharm,pycharm+all,pycharm+iml,python,visualstudiocode,react,django
docker/temp/*
docker/init.sql/*
docker/*.env
!docker/sample.*.env
docker/public_tools.json
docker/proxy_overrides.yaml

# Tool development
tools/*/sdks/
tools/*/data_dir/

docker/workflow_data/

.pre-commit-config.yaml (new file, 200 lines)
@@ -0,0 +1,200 @@
---
# See https://pre-commit.com for more information
# See https://pre-commit.com/hooks.html for more hooks
# - Added unstract feature flag auto generated code to flake8 exclude list
# Force all unspecified python hooks to run python 3.9
default_language_version:
  python: python3.9
default_stages:
  - commit
repos:
  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v4.5.0
    hooks:
      - id: trailing-whitespace
        exclude_types:
          - "markdown"
      - id: end-of-file-fixer
      - id: check-yaml
        args: [--unsafe]
      - id: check-added-large-files
        args: ["--maxkb=10240"]
      - id: check-case-conflict
      - id: check-docstring-first
      - id: check-ast
      - id: check-json
        exclude: ".vscode/launch.json"
      - id: check-executables-have-shebangs
      - id: check-shebang-scripts-are-executable
      - id: check-toml
      - id: debug-statements
      - id: detect-private-key
      - id: check-merge-conflict
      - id: check-symlinks
      - id: destroyed-symlinks
      - id: forbid-new-submodules
      - id: mixed-line-ending
      - id: no-commit-to-branch
  - repo: https://github.com/adrienverge/yamllint
    rev: v1.35.1
    hooks:
      - id: yamllint
        args: ["-d", "relaxed"]
        language: system
  - repo: https://github.com/rhysd/actionlint
    rev: v1.6.26
    hooks:
      - id: actionlint-docker
        args: [-ignore, 'label ".+" is unknown']
  - repo: https://github.com/psf/black
    rev: 24.2.0
    hooks:
      - id: black
        args: [--config=pyproject.toml, -l 80]
        language: system
        exclude: |
          (?x)^(
            unstract/flags/src/unstract/flags/evaluation_.*\.py|
          )$
  - repo: https://github.com/pycqa/flake8
    rev: 7.0.0
    hooks:
      - id: flake8
        args: [--max-line-length=80]
        exclude: |
          (?x)^(
            .*migrations/.*\.py|
            core/tests/.*|
            unstract/flags/src/unstract/flags/evaluation_.*\.py|
          )$
  - repo: https://github.com/pycqa/isort
    rev: 5.13.2
    hooks:
      - id: isort
        files: "\\.(py)$"
        args:
          [
            "--profile",
            "black",
            "--filter-files",
            --settings-path=pyproject.toml,
          ]
  - repo: https://github.com/hadialqattan/pycln
    rev: v2.4.0
    hooks:
      - id: pycln
        args: [--config=pyproject.toml]
  - repo: https://github.com/pycqa/docformatter
    rev: v1.7.5
    hooks:
      - id: docformatter
  # - repo: https://github.com/MarcoGorelli/absolufy-imports
  #   rev: v0.3.1
  #   hooks:
  #     - id: absolufy-imports
  #       files: ^backend/
  - repo: https://github.com/asottile/pyupgrade
    rev: v3.15.0
    hooks:
      - id: pyupgrade
        entry: pyupgrade --py39-plus --keep-runtime-typing
        types:
          - python
  - repo: https://github.com/gitleaks/gitleaks
    rev: v8.18.2
    hooks:
      - id: gitleaks
  - repo: https://github.com/hadolint/hadolint
    rev: v2.12.1-beta
    hooks:
      - id: hadolint-docker
        args:
          - --ignore=DL3003
          - --ignore=DL3008
          - --ignore=DL3013
          - --ignore=DL3018
          - --ignore=SC1091
        files: Dockerfile$
  - repo: https://github.com/asottile/yesqa
    rev: v1.5.0
    hooks:
      - id: yesqa
  - repo: https://github.com/pre-commit/mirrors-eslint
    rev: "v9.0.0-beta.0"  # Use the sha / tag you want to point at
    hooks:
      - id: eslint
        args: [--config=frontend/.eslintrc.json]
        files: \.[jt]sx?$  # *.js, *.jsx, *.ts and *.tsx
        types: [file]
        additional_dependencies:
          - eslint@8.41.0
          - eslint-config-google@0.14.0
          - eslint-config-prettier@8.8.0
          - eslint-plugin-prettier@4.2.1
          - eslint-plugin-react@7.32.2
          - eslint-plugin-import@2.25.2
  - repo: https://github.com/Lucas-C/pre-commit-hooks-nodejs
    rev: v1.1.2
    hooks:
      - id: htmlhint
  - repo: https://github.com/asottile/pyupgrade
    rev: v3.15.0
    hooks:
      - id: pyupgrade
        entry: pyupgrade --py38-plus --keep-runtime-typing
        types:
          - python
  - repo: https://github.com/pre-commit/mirrors-mypy
    rev: v1.8.0
    hooks:
      - id: mypy
        language: system
        entry: mypy .
        pass_filenames: false
        # IMPORTANT!
        # Keep args same as tool.mypy section in pyproject.toml
        args:
          [
            --allow-subclassing-any,
            --allow-untyped-decorators,
            --check-untyped-defs,
            --exclude, ".*migrations/.*.py",
            --exclude, "backend/prompt/.*",
            --exclude, "document-service/.*",
            --exclude, "unstract/connectors/tests/.*",
            --exclude, "unstract/core/.*",
            --exclude, "unstract/flags/src/unstract/flags/.*",
            --exclude, "__pypackages__/.*",
            --follow-imports, "silent",
            --ignore-missing-imports,
            --implicit-reexport,
            --pretty,
            --python-version=3.9,
            --show-column-numbers,
            --show-error-codes,
            --strict,
            --warn-redundant-casts,
            --warn-return-any,
            --warn-unreachable,
            --warn-unused-configs,
            --warn-unused-ignores,
          ]
  - repo: https://github.com/igorshubovych/markdownlint-cli
    rev: v0.39.0
    hooks:
      - id: markdownlint
        args: [--disable, MD013]
      - id: markdownlint-fix
        args: [--disable, MD013]
  - repo: https://github.com/pdm-project/pdm
    rev: 2.12.3
    hooks:
      - id: pdm-lock-check
  - repo: local
    hooks:
      - id: check-django-migrations
        name: Check django migrations
        entry: sh -c 'pdm run `find . -name "manage.py" -not -path "*/.venv/*"` makemigrations --check --dry-run --no-input'
        language: system
        types: [python]  # hook only runs if a python file is staged
        pass_filenames: false

CONTRIBUTE.md (new file, 3 lines)
@@ -0,0 +1,3 @@
Conventions

- Wherever you add YAML files, the preferred extension is `.yaml`

README.md (new file, 199 lines)
@@ -0,0 +1,199 @@
# Unstract

[](https://pdm-project.org)

TODO: Write a few lines about the project.

## System Requirements

- docker
- git

## Running with docker compose

- All services needed by the backend can be run with

  ```
  cd docker/
  VERSION=test docker compose -f docker-compose.build.yaml build
  VERSION=test docker compose -f docker-compose.yaml up -d
  ```

  Additional information on running with Docker can be found in [DOCKERISING.md](/DOCKERISING.md)

- Use the `-f` flag to run all dependencies necessary for development; this also runs containers needed for testing, such as Minio.

  ```
  docker compose -f docker-compose-dev-essentials.yaml up
  ```

- It might take some time on the first run to pull the images.

## Running locally

### Installation

- Install the below libraries, which are needed to run Unstract
  - Linux

    ```
    sudo apt install build-essential pkg-config libpoppler-cpp-dev libmagic-dev python3-dev
    ```

  - Mac

    ```
    brew install pkg-config poppler freetds libmagic
    ```

### Create your virtual env

- In order to install dependencies and run a package, ensure that you've sourced a virtual environment within that package. All commands in this repository assume that you have sourced your required venv.

  ```
  cd <package_to_use>
  python -m venv .venv
  source .venv/bin/activate
  ```

### Install dependencies with PDM

- This repository makes use of [PDM](https://github.com/pdm-project/pdm) for managing dependencies with the help of a virtual environment.
- If you haven't installed PDM on your machine yet,
  - Install it using the below command

    ```
    curl -sSL https://pdm.fming.dev/install-pdm.py | python3 -
    ```

  - Or install it from PyPI using `pip`

    ```
    pip install pdm
    ```

Ensure you're running the PDM commands from the corresponding package root.

- Install dependencies for running the package with

  ```
  pdm install
  ```

  This installs dev dependencies as well by default.
- For production, install the requirements with

  ```
  pdm install --prod
  ```

- With PDM it's possible to run some services from any directory within this repository. To list the possible scripts that can be executed

  ```
  pdm run -l
  ```

- Add a new dependency with the commands below (ensure you're running them from the correct project's root).
  Perform an editable install with `-e` only for local development.

  ```
  pdm add <package_from_PyPI>
  pdm add -e <relative_path_to_local_package>
  ```

- List all dependencies with

  ```
  pdm list
  ```

- After updating `pyproject.toml`s with a newly added dependency, the lock file can be updated with

  ```
  pdm lock
  ```

- Refer to [PDM's documentation](https://pdm.fming.dev/latest/reference/cli/) for further details.

### Configuring Postgres

- Create a Postgres user and DB for the backend and configure it like so

  ```
  POSTGRES_USER: unstract_dev
  POSTGRES_PASSWORD: unstract_pass
  POSTGRES_DB: unstract_db
  ```

  If you require a different config, make sure the necessary envs from [backend/sample.env](/backend/sample.env) are exported.

- Execute the script [backend/init.sql](/backend/init.sql), which adds roles and creates a DB and extension needed for the ZS Document Indexer tool to work.
  Make sure that [pgvector](https://github.com/pgvector/pgvector#installation) is installed.

### Pre-commit hooks

- We use pre-commit to run some hooks whenever code is pushed, to perform linting and static code analysis among other checks.
- Ensure dev dependencies are installed and you're in the virtual env.
- Install hooks with `pre-commit install` or `pdm run pre-commit install`.
- Manually trigger pre-commit hooks in the following ways:

  ```bash
  #
  # Using the tool directly
  #
  # Run all pre-commit hooks
  pre-commit run
  # Run specific pre-commit hook
  pre-commit run flake8
  # Run mypy pre-commit hook for selected folder
  pre-commit run mypy --files prompt-service/**/*.py
  # Run mypy for selected folder
  mypy prompt-service/**/*.py

  #
  # Using pdm to run the scripts
  #
  # Run all pre-commit hooks
  pdm run pre-commit run
  # Run specific pre-commit hook
  pdm run pre-commit run flake8
  # Run mypy pre-commit hook for selected folder
  pdm run pre-commit run mypy --files prompt-service/**/*.py
  # Run mypy for selected folder
  pdm run mypy prompt-service/**/*.py
  ```

### Backend

- Check [backend/README.md](/backend/README.md) for running the backend.

### Frontend

- Install dependencies with `npm install`
- Start the server with `npm start`

### Traefik Proxy Overrides

It is possible to simultaneously run a few services directly on the docker host while others run as docker containers via docker compose.
This enables seamless development without worrying about the deployment of services you are not concerned with.

We just need to override the default Traefik proxy routing to allow this.

1. Copy `docker/sample.proxy_overrides.yaml` to `docker/proxy_overrides.yaml`.
   Modify it to update Traefik proxy routes for services running directly on the docker host (`host.docker.internal:<port>`).

2. Update the host name of dependency components in the config of services running directly on the docker host:
   - Replace with `*.localhost` IF the container port is exposed on the docker host
   - **OR** use container IPs obtained via `docker network inspect unstract-network`
   - **OR** run `docker/scripts/resolve_container_svc_from_host.sh` IF the container port is NOT exposed on the docker host or if you want to keep dependency host names unchanged

Run the services.

#### Conflicting Host Names

When the same host name environment variables are used by both a service running locally and a service
running in a container (for example, when run from a tool), host name resolution conflicts can arise for the following:

- `localhost` -> Using this inside a container points to the container itself, not the host.
- `host.docker.internal` -> Meant to be used inside containers only, to get the host IP.
  It does not make sense to use it in services running locally.

*In such cases, use another host name and point it to the host IP in `/etc/hosts`.*

For example, the backend uses the PROMPT_HOST environment variable, which is also supplied
in the Tool configuration when spawning Tool containers. If the backend is running
locally and the Tools are in containers, we could set the value to
`prompt-service` and add it to `/etc/hosts` as shown below.

```
<host_local_ip> prompt-service
```

backend/README.md (new file, 136 lines)
@@ -0,0 +1,136 @@
# Unstract Backend

Contains the backend services for Unstract, written with Django and DRF.

## Dependencies

1. Postgres
1. Redis

## Getting started

**NOTE**: All commands are executed from `/backend` and require the venv to be active. Refer to [these steps](/README.md#create-your-virtual-env) to create/activate your venv.

### Install and run manually

- Ensure that you've sourced your virtual environment and installed the dependencies mentioned [here](/README.md#create-your-virtual-env).

- If you plan to run the django server locally, make sure the dependent services are up (either locally or through docker compose).
- Copy `sample.env` into `.env` and update the necessary variables. For example:

  ```
  DJANGO_SETTINGS_MODULE='backend.settings.dev'
  DB_HOST='localhost'
  DB_USER='unstract_dev'
  DB_PASSWORD='unstract_pass'
  DB_NAME='unstract_db'
  DB_PORT=5432
  ```

- If you've made changes to the model, run `python manage.py makemigrations`; otherwise skip this step.
- Run the following to apply any migrations to the DB and start the server

  ```
  python manage.py migrate
  python manage.py runserver localhost:8000
  ```

- The server will start and run at port 8000. (<http://localhost:8000>)

## Asynchronous execution/pipeline execution

- Working with celery
- Each pipeline or shared task is added to the queue (Redis), and a worker consumes it from the queue; a minimal sketch of such a task is shown below.
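The following is an illustrative sketch only, assuming a standard Celery setup; the task name and payload are hypothetical and not taken from this codebase.

```python
# Illustrative sketch of a Celery shared task; the task name and
# payload below are hypothetical, not part of this commit.
from celery import shared_task


@shared_task
def run_pipeline(pipeline_id: str) -> str:
    """Queued in Redis by the caller and picked up by a worker
    started with `celery -A backend worker`."""
    # ... fetch the pipeline and execute its workflow here ...
    return f"executed pipeline {pipeline_id}"


# Calling code enqueues the task without blocking:
# run_pipeline.delay("some-pipeline-id")
```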
### Run Execution Worker

Run the following command to start the worker:

```bash
celery -A backend worker --loglevel=info
```

### Worker Dashboard

- Ensure the package `flower` is installed in the current environment
- Run command

  ```bash
  celery -A backend flower
  ```

  This command will start Flower on the default port (5555); it can be accessed via a web browser. Flower provides a user-friendly interface for monitoring and managing Celery tasks.

## Connecting to Postgres

Follow the below steps to connect to the postgres DB running with `docker compose`.

1. Exec into a shell within the postgres container

   ```
   docker compose exec -it db bash
   ```

2. Connect to the db as the specified user

   ```
   psql -d unstract_db -U unstract_dev
   ```

3. Execute PSQL commands within the shell.

## API Docs

While running the backend server locally, access the API documentation that's auto generated at
the backend endpoint `/api/v1/doc/`.

**NOTE:** There are issues accessing this when the django server is run with gunicorn (in case of running with
a container).

- [Account](account/api_doc.md)
- [FileManagement](file_management/api_doc.md)

## Connectors

### Google Drive

The Google Drive connector makes use of the [PyDrive2](https://pypi.org/project/PyDrive2/) library and supports only OAuth 2.0 authentication.
To set it up, follow the first step highlighted in [Google's docs](https://developers.google.com/identity/protocols/oauth2#1.-obtain-oauth-2.0-credentials-from-the-dynamic_data.setvar.console_name-.) and set the client ID and client secret
as envs in `backend/.env`

```
GOOGLE_OAUTH2_KEY="<client-id>"
GOOGLE_OAUTH2_SECRET="<client-secret>"
```

# Archived - (EXPERIMENTAL)

## Accessing the admin site

- If it's the first time, create a super user and follow the on-screen instructions

  ```
  python manage.py createsuperuser
  ```

- Register your models in `<app>/admin.py`, for example

  ```
  from django.contrib import admin
  from .models import Prompt

  admin.site.register(Prompt)
  ```

- Make sure the server is running and hit the `/admin` endpoint

## Running unit tests

Unit tests are run with [pytest](https://docs.pytest.org/en/7.3.x/) and [pytest-django](https://pytest-django.readthedocs.io/en/latest/index.html)

```
pytest
pytest prompt  # To run for an app named prompt
```

All tests are organized within an app, e.g. `prompt/tests/test_urls.py` (see the sketch below).

**NOTE:** The django server need not be up to run the tests, however the DB needs to be running.
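For illustration, a minimal pytest-django URL test in that layout could look like the sketch below; the route name `prompt-list` is a hypothetical placeholder, not a route defined in this commit.

```python
# prompt/tests/test_urls.py - illustrative sketch; the route name
# "prompt-list" is hypothetical.
import pytest
from django.urls import reverse


@pytest.mark.django_db
def test_prompt_list_url_resolves():
    # reverse() raises NoReverseMatch if the route is missing,
    # which is the point of a URL smoke test like this.
    url = reverse("prompt-list")  # hypothetical route name
    assert url.startswith("/")
```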

backend/account/ReadMe.md (new file, 26 lines)
@@ -0,0 +1,26 @@
# Basic WorkFlow

`We can add workflows here`

## Login

### Steps

1. Login
2. Get Organizations
3. Set Organization
4. Use organizational APIs /unstract/<org_id>/

## Switch organization

1. Get Organizations
2. Set Organization
3. Use organizational APIs /unstract/<org_id>/

## Get current user and Organization data

- Use the Get User Profile and Get Organization Info APIs

## Signout

1. Signout API
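As a rough illustration of the login and organization-switch flows above, a client session could look like the sketch below. The base URL is a placeholder and the endpoint paths follow [api_doc.md](api_doc.md); a real client would also send the `X-CSRFToken` header on POSTs.

```python
# Illustrative client walkthrough of the flow above; base_url is a placeholder.
import requests

base_url = "http://localhost:8000"  # assumption: backend running locally
session = requests.Session()

# 1. Login (redirects to the Auth0 login page; session cookies are stored)
session.get(f"{base_url}/login")

# 2. Get Organizations
orgs = session.get(f"{base_url}/organization").json()["organizations"]

# 3. Set Organization
org_id = orgs[0]["id"]
session.post(f"{base_url}/organization/{org_id}/set")

# 4. Use organizational APIs under /unstract/<org_id>/
profile = session.get(f"{base_url}/unstract/{org_id}/profile").json()
```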

backend/account/__init__.py (new file, empty)

backend/account/admin.py (new file, 5 lines)
@@ -0,0 +1,5 @@
from django.contrib import admin

from .models import Organization, User

admin.site.register([Organization, User])

backend/account/api_doc.md (new file, 202 lines)
@@ -0,0 +1,202 @@
# Authentication APIs

### Unauthorized Attempt

**Status Code:** 401

**Response:**

```json
{
  "message": "Unauthorized"
}
```

### Login

```http
GET /login
```

Sample URL:
> <base_url>/login

Response: Auth0 login page

### Logout

```http
GET /logout
```

Sample URL:
> <base_url>/logout

Response: OK (redirect to Auth0 page)

### Signup

```http
GET /signup
```

Sample URL:
> <base_url>/signup

Response: Auth0 signup page

### Get User Profile

```http
GET /profile
```

Sample URL:
> <base_url>/unstract/<org_id>/profile

Response:

```json
{
  "user": {
    "id": "6",
    "email": "iamali003@gmail.com",
    "name": "Ali",
    "display_name": "Ali",
    "family_name": null,
    "picture": null
  }
}
```

### Password Reset

```http
POST /reset_password
```

Sample URL:
> <base_url>/unstract/<org_id>/reset_password

Response:

```json
{
  "status": "failed",
  "message": "user doesn't have Username-Password-Authentication"
}
```

### Get Organizations of User

```http
GET /organization
```

Sample URL:
> <base_url>/organization

Response:

```json
{
  "message": "success",
  "organizations": [
    {
      "id": "org_Z12elHhCcPH5rPD7",
      "display_name": "Personal Org",
      "name": "personal"
    },
    {
      "id": "org_CB46CBskR8BxFjVV",
      "display_name": "ali Test",
      "name": "ali1"
    }
  ]
}
```

### Select an Organization to Use

```http
POST /set
```

Sample URL:
> <base_url>/organization/<org_id>/set

Response:

```json
{
  "user": {
    "id": "6",
    "email": "iamali003@gmail.com",
    "name": "Ali",
    "display_name": "Ali",
    "family_name": null,
    "picture": null
  },
  "organization": {
    "display_name": "ali Test",
    "name": "ali1",
    "organization_id": "org_CB46CBskR8BxFjVV"
  }
}
```

### Get Organization Members

```http
GET /members
```

Sample URL:
> <base_url>/unstract/<org_id>/members

Response:

```json
{
  "message": "success",
  "members": [
    {
      "user_id": "google-oauth2|102763382532901780910",
      "email": "iamali003@gmail.com",
      "name": "",
      "picture": null
    }
  ]
}
```

### Get Organization Info

```http
GET /organization
```

Sample URL:
> <base_url>/unstract/<org_id>/organization

Response:

```json
{
  "message": "success",
  "organization": {
    "name": "ali1",
    "display_name": "ali Test",
    "organization_id": "org_CB46CBskR8BxFjVV",
    "created_at": "2023-06-26 04:42:40.905458+00:00"
  }
}
```

## Ref

- Postman Collection: [Postman collection link](https://api.postman.com/collections/24537488-9380ea92-d1e0-45f4-827c-c5cc9d0370b8?access_key=PMAT-01H3VGHTM9SR01MHXA95G1RTWB)
- For testing with Postman:
  - set cookies
  - set the X-CSRFToken header for POST requests
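In code, those Postman notes amount to echoing the `csrftoken` cookie back as an `X-CSRFToken` header on POST requests. A minimal sketch follows, with the base URL as a placeholder and the org id taken from the examples above; the cookie name assumes Django's default.

```python
# Minimal sketch of the CSRF handling described above; base_url is a
# placeholder and "csrftoken" assumes Django's default cookie name.
import requests

base_url = "http://localhost:8000"
session = requests.Session()
session.get(f"{base_url}/login")  # login sets the csrftoken cookie

# Django's CSRF middleware expects the cookie echoed back as a header.
csrf_token = session.cookies.get("csrftoken")
response = session.post(
    f"{base_url}/unstract/org_CB46CBskR8BxFjVV/reset_password",
    headers={"X-CSRFToken": csrf_token},
)
print(response.status_code, response.json())
```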

backend/account/apps.py (new file, 6 lines)
@@ -0,0 +1,6 @@
from django.apps import AppConfig


class AccountConfig(AppConfig):
    default_auto_field = "django.db.models.BigAutoField"
    name = "account"

backend/account/authentication_controller.py (new file, 508 lines; listing truncated)
@@ -0,0 +1,508 @@
|
||||
import logging
|
||||
from typing import Any, Optional, Union
|
||||
from urllib.parse import urlencode
|
||||
|
||||
from account.authentication_helper import AuthenticationHelper
|
||||
from account.authentication_plugin_registry import AuthenticationPluginRegistry
|
||||
from account.authentication_service import AuthenticationService
|
||||
from account.cache_service import CacheService
|
||||
from account.constants import (
|
||||
AuthoErrorCode,
|
||||
Cookie,
|
||||
ErrorMessage,
|
||||
OrganizationMemberModel,
|
||||
)
|
||||
from account.custom_exceptions import (
|
||||
DuplicateData,
|
||||
Forbidden,
|
||||
UserNotExistError,
|
||||
)
|
||||
from account.dto import (
|
||||
CallbackData,
|
||||
MemberInvitation,
|
||||
OrganizationData,
|
||||
UserInfo,
|
||||
UserInviteResponse,
|
||||
UserRoleData,
|
||||
)
|
||||
from account.exceptions import OrganizationNotExist
|
||||
from account.models import Organization, User
|
||||
from account.organization import OrganizationService
|
||||
from account.serializer import (
|
||||
GetOrganizationsResponseSerializer,
|
||||
OrganizationSerializer,
|
||||
SetOrganizationsResponseSerializer,
|
||||
)
|
||||
from account.user import UserService
|
||||
from django.conf import settings
|
||||
from django.contrib.auth import login as django_login
|
||||
from django.contrib.auth import logout as django_logout
|
||||
from django.db.utils import IntegrityError
|
||||
from django.middleware import csrf
|
||||
from django.shortcuts import redirect
|
||||
from django_tenants.utils import tenant_context
|
||||
from psycopg2.errors import UndefinedTable
|
||||
from rest_framework import status
|
||||
from rest_framework.request import Request
|
||||
from rest_framework.response import Response
|
||||
from tenant_account.models import OrganizationMember as OrganizationMember
|
||||
from tenant_account.organization_member_service import OrganizationMemberService
|
||||
|
||||
Logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AuthenticationController:
|
||||
"""Authentication Controller This controller class manages user
|
||||
authentication processes."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""This method initializes the controller by selecting the appropriate
|
||||
authentication plugin based on availability."""
|
||||
self.authentication_helper = AuthenticationHelper()
|
||||
self.organization_member_service = OrganizationMemberService()
|
||||
if AuthenticationPluginRegistry.is_plugin_available():
|
||||
self.auth_service: AuthenticationService = (
|
||||
AuthenticationPluginRegistry.get_plugin()
|
||||
)
|
||||
else:
|
||||
self.auth_service = AuthenticationService()
|
||||
|
||||
def user_login(
|
||||
self,
|
||||
request: Request,
|
||||
) -> Any:
|
||||
return self.auth_service.user_login(request)
|
||||
|
||||
def user_signup(self, request: Request) -> Any:
|
||||
return self.auth_service.user_signup(request)
|
||||
|
||||
    def authorization_callback(
        self, request: Request, backend: str = settings.DEFAULT_MODEL_BACKEND
    ) -> Any:
        """Handle the authorization callback.

        This function processes the authorization callback from
        an external service.

        Args:
            request (Request): Request instance
            backend (str, optional): Authentication backend used for login.
                Defaults to settings.DEFAULT_MODEL_BACKEND.

        Returns:
            Any: Redirect response
        """
        callback_data: CallbackData = self.auth_service.get_callback_data(
            request=request
        )
        user: User = self.get_or_create_user_by_email(request, callback_data)
        try:
            member = self.auth_service.handle_invited_user_while_callback(
                request=request, user=user
            )
        except Exception as ex:
            # Error code reference:
            # frontend/src/components/error/GenericError/GenericError.jsx
            if ex.code == AuthoErrorCode.IDM:  # type: ignore
                query_params = {"code": AuthoErrorCode.IDM}
                return redirect(f"{settings.ERROR_URL}?{urlencode(query_params)}")
            elif ex.code == AuthoErrorCode.UMM:  # type: ignore
                query_params = {"code": AuthoErrorCode.UMM}
                return redirect(f"{settings.ERROR_URL}?{urlencode(query_params)}")

            return redirect(f"{settings.ERROR_URL}")

        if member.organization_id and member.role and len(member.role) > 0:
            organization: Optional[Organization] = (
                OrganizationService.get_organization_by_org_id(
                    member.organization_id
                )
            )
            if organization:
                try:
                    self.create_tenant_user(
                        organization=organization, user=user
                    )
                except UndefinedTable:
                    pass

        response = self.auth_service.handle_authorization_callback(
            user=user,
            data=callback_data,
            redirect_url=request.GET.get("redirect_url"),
        )
        django_login(request, user, backend)

        return response

    def user_organizations(self, request: Request) -> Any:
        """List the organizations that a user belongs to.

        Args:
            request (Request): Request instance

        Returns:
            Response: Serialized list of the user's organizations
        """
        try:
            organizations = self.auth_service.user_organizations(request)
        except Exception as ex:
            self.user_logout(request)
            if ex.code == AuthoErrorCode.USF:  # type: ignore
                response = Response(
                    status=status.HTTP_412_PRECONDITION_FAILED,
                    data={"domain": ex.data.get("domain")},  # type: ignore
                )
                return response
        user: User = request.user
        org_ids = {org.id for org in organizations}
        CacheService.set_user_organizations(user.user_id, list(org_ids))

        serialized_organizations = GetOrganizationsResponseSerializer(
            organizations, many=True
        ).data
        response = Response(
            status=status.HTTP_200_OK,
            data={
                "message": "success",
                "organizations": serialized_organizations,
            },
        )
        if Cookie.CSRFTOKEN not in request.COOKIES:
            csrf_token = csrf.get_token(request)
            response.set_cookie(Cookie.CSRFTOKEN, csrf_token)

        return response

    def set_user_organization(
        self, request: Request, organization_id: str
    ) -> Response:
        user: User = request.user
        organization_ids = CacheService.get_user_organizations(user.user_id)
        if not organization_ids:
            z_organizations: list[OrganizationData] = (
                self.auth_service.get_organizations_by_user_id(user.user_id)
            )
            organization_ids = {org.id for org in z_organizations}
        if organization_id and organization_id in organization_ids:
            organization = OrganizationService.get_organization_by_org_id(
                organization_id
            )
            if not organization:
                try:
                    organization_data: OrganizationData = (
                        self.auth_service.get_organization_by_org_id(
                            organization_id
                        )
                    )
                except ValueError:
                    raise OrganizationNotExist()
                try:
                    organization = OrganizationService.create_organization(
                        organization_data.name,
                        organization_data.display_name,
                        organization_data.id,
                    )
                except IntegrityError:
                    raise DuplicateData(
                        f"{ErrorMessage.ORGANIZATION_EXIST}, "
                        f"{ErrorMessage.DUPLICATE_API}"
                    )
            self.create_tenant_user(organization=organization, user=user)
            user_info: Optional[UserInfo] = self.get_user_info(request)
            serialized_user_info = SetOrganizationsResponseSerializer(
                user_info
            ).data
            organization_info = OrganizationSerializer(organization).data
            response: Response = Response(
                status=status.HTTP_200_OK,
                data={
                    "user": serialized_user_info,
                    "organization": organization_info,
                },
            )
            # Update the user's session data in Redis
            user_session_info: dict[str, Any] = (
                CacheService.get_user_session_info(user.email)
            )
            user_session_info["current_org"] = organization_id
            CacheService.set_user_session_info(user_session_info)
            response.set_cookie(Cookie.ORG_ID, organization_id)
            return response
        return Response(status=status.HTTP_403_FORBIDDEN)

    def get_user_info(self, request: Request) -> Optional[UserInfo]:
        return self.auth_service.get_user_info(request)

    def is_admin_by_role(self, role: str) -> bool:
        """Check whether the role acts as an admin role in the context of
        the authentication plugin.

        Args:
            role (str): Role name

        Returns:
            bool: True if the role is an admin role, False otherwise
        """
        return self.auth_service.is_admin_by_role(role=role)

    def get_organization_info(self, org_id: str) -> Optional[Organization]:
        organization = self.auth_service.get_organization_info(org_id)
        if not organization:
            organization = OrganizationService.get_organization_by_org_id(
                org_id=org_id
            )
        return organization

    def make_organization_and_add_member(
        self,
        user_id: str,
        user_name: str,
        organization_name: Optional[str] = None,
        display_name: Optional[str] = None,
    ) -> Optional[OrganizationData]:
        return self.auth_service.make_organization_and_add_member(
            user_id, user_name, organization_name, display_name
        )

    def make_user_organization_name(self) -> str:
        return self.auth_service.make_user_organization_name()

    def make_user_organization_display_name(self, user_name: str) -> str:
        return self.auth_service.make_user_organization_display_name(user_name)

    def user_logout(self, request: Request) -> Response:
        response = self.auth_service.user_logout(request=request)
        django_logout(request)
        return response

    def get_organization_members_by_org_id(
        self, organization_id: Optional[str] = None
    ) -> list[OrganizationMember]:
        members: list[OrganizationMember] = OrganizationMember.objects.all()
        return members

    def get_organization_members_by_user(
        self, user: User
    ) -> OrganizationMember:
        member: OrganizationMember = OrganizationMember.objects.filter(
            user=user
        ).first()
        return member

    def get_user_roles(self) -> list[UserRoleData]:
        return self.auth_service.get_roles()

    def get_user_invitations(
        self, organization_id: str
    ) -> list[MemberInvitation]:
        return self.auth_service.get_invitations(
            organization_id=organization_id
        )

    def delete_user_invitation(
        self, organization_id: str, invitation_id: str
    ) -> bool:
        return self.auth_service.delete_invitation(
            organization_id=organization_id, invitation_id=invitation_id
        )

    def reset_user_password(self, user: User) -> Response:
        return self.auth_service.reset_user_password(user)

    def invite_user(
        self,
        admin: User,
        org_id: str,
        user_list: list[dict[str, Union[str, None]]],
    ) -> list[UserInviteResponse]:
        """Invite users to join an organization.

        Args:
            admin (User): Admin user initiating the invitation.
            org_id (str): ID of the organization to which users are invited.
            user_list (list[dict[str, Union[str, None]]]):
                List of user details for the invitation.

        Returns:
            list[UserInviteResponse]: List of responses for each
                user invitation.
        """
        admin_user = OrganizationMember.objects.get(user=admin.id)
        if not self.auth_service.is_organization_admin(admin_user):
            raise Forbidden()
        response = []
        for user_item in user_list:
            email = user_item.get("email")
            role = user_item.get("role")
            if email:
                user = self.organization_member_service.get_user_by_email(
                    email=email
                )
                message: str = "User invitation successful"
                if user:
                    # User is already a member of this organization
                    invite_status = False
                    message = "User is already part of the current organization."
                else:
                    try:
                        self.auth_service.check_user_organization_association(
                            user_email=email
                        )
                        invite_status = self.auth_service.invite_user(
                            admin_user, org_id, email, role=role
                        )
                    except Exception as exception:
                        invite_status = False
                        message = exception.message  # type: ignore
                response.append(
                    UserInviteResponse(
                        email=email,
                        status="success" if invite_status else "failed",
                        message=message,
                    )
                )
        return response

    def remove_users_from_organization(
        self, admin: User, organization_id: str, user_emails: list[str]
    ) -> bool:
        admin_user = OrganizationMember.objects.get(user=admin.id)
        user_ids = OrganizationMember.objects.filter(
            user__email__in=user_emails
        ).values_list(
            OrganizationMemberModel.USER_ID, OrganizationMemberModel.ID
        )
        user_ids_list: list[str] = []
        ids_list: list[str] = []
        for user in user_ids:
            user_ids_list.append(user[0])
            ids_list.append(user[1])
        if len(user_ids_list) > 0:
            is_removed = self.auth_service.remove_users_from_organization(
                admin=admin_user,
                organization_id=organization_id,
                user_ids=user_ids_list,
            )
        else:
            is_removed = False
        if is_removed:
            OrganizationMember.objects.filter(user__in=ids_list).delete()
        return is_removed

    def add_user_role(
        self, admin: User, org_id: str, email: str, role: str
    ) -> Optional[str]:
        admin_user = OrganizationMember.objects.get(user=admin.id)
        user = self.organization_member_service.get_user_by_email(email=email)
        if user:
            current_roles = self.auth_service.add_organization_user_role(
                admin_user, org_id, user.user.user_id, [role]
            )
            if current_roles:
                self.save_organization_user_role(
                    user_id=user.user.user_id, role=current_roles[0]
                )
            return current_roles[0]
        else:
            return None

    def remove_user_role(
        self, admin: User, org_id: str, email: str, role: str
    ) -> Optional[str]:
        admin_user = OrganizationMember.objects.get(user=admin.id)
        organization_member = (
            self.organization_member_service.get_user_by_email(email=email)
        )
        if organization_member:
            current_roles = self.auth_service.remove_organization_user_role(
                admin_user, org_id, organization_member.user.user_id, [role]
            )
            if current_roles:
                self.save_organization_user_role(
                    user_id=organization_member.user.user_id,
                    role=current_roles[0],
                )
            return current_roles[0]
        else:
            return None

    def save_organization_user_role(self, user_id: str, role: str) -> None:
        organization_user = (
            self.organization_member_service.get_user_by_user_id(
                user_id=user_id
            )
        )
        if organization_user:
            # Only a single role per user is considered here
            organization_user.role = role
            organization_user.save()

    def create_tenant_user(
        self, organization: Organization, user: User
    ) -> None:
        with tenant_context(organization):
            existing_tenant_user = (
                self.organization_member_service.get_user_by_id(id=user.id)
            )
            if existing_tenant_user:
                Logger.info(f"{existing_tenant_user.user.email} already exists")
            else:
                account_user = self.get_or_create_user(user=user)
                if account_user:
                    user_roles = (
                        self.auth_service.get_organization_role_of_user(
                            user_id=account_user.user_id,
                            organization_id=organization.organization_id,
                        )
                    )
                    user_role = user_roles[0]
                    tenant_user: OrganizationMember = OrganizationMember(
                        user=user, role=user_role
                    )
                    tenant_user.save()
                else:
                    raise UserNotExistError()

    def get_or_create_user_by_email(
        self, request: Request, callback_data: CallbackData
    ) -> Union[User, OrganizationMember]:
        email = callback_data.email
        user_service = UserService()
        user = user_service.get_user_by_email(email)
        if not user:
            user_id = callback_data.user_id
            user = user_service.create_user(email, user_id)
        return user

    def get_or_create_user(
        self, user: User
    ) -> Optional[Union[User, OrganizationMember]]:
        user_service = UserService()
        if user.id:
            account_user: Optional[User] = user_service.get_user_by_id(user.id)
            if account_user:
                return account_user
            elif user.email:
                account_user = user_service.get_user_by_email(email=user.email)
                if account_user:
                    return account_user
                if user.user_id:
                    user.save()
                    return user
        elif user.email and user.user_id:
            account_user = user_service.create_user(
                email=user.email, user_id=user.user_id
            )
            return account_user
        return None

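For orientation: a minimal sketch of how this controller's plugin fallback behaves. The import path of the controller module is an assumption (the file name is not shown in this diff hunk); only the behavior of AuthenticationController and AuthenticationPluginRegistry comes from the code above.

# Hypothetical illustration; not part of this commit.
from account.authentication_controller import AuthenticationController

controller = AuthenticationController()
# With no plugin under plugins/authentication, auth_service is the
# built-in AuthenticationService, whose user_login() raises
# MethodNotImplemented (HTTP 501); with one active plugin installed,
# the same call is delegated to the plugin's service class.
print(type(controller.auth_service).__name__)
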
38
backend/account/authentication_helper.py
Normal file
@@ -0,0 +1,38 @@
from typing import Any

from account.constants import DefaultOrg
from account.dto import CallbackData, MemberData, OrganizationData
from rest_framework.request import Request


class AuthenticationHelper:
    def __init__(self) -> None:
        pass

    def get_organizations_by_user_id(self) -> list[OrganizationData]:
        organization_data: OrganizationData = OrganizationData(
            id=DefaultOrg.MOCK_ORG,
            display_name=DefaultOrg.MOCK_ORG,
            name=DefaultOrg.MOCK_ORG,
        )
        return [organization_data]

    def get_authorize_token(self, request: Request) -> CallbackData:
        return CallbackData(
            user_id=DefaultOrg.MOCK_USER_ID,
            email=DefaultOrg.MOCK_USER_EMAIL,
            token="",
        )

    def list_of_members_from_user_model(
        self, model_data: list[Any]
    ) -> list[MemberData]:
        members: list[MemberData] = []
        for data in model_data:
            user_id = data.user_id
            email = data.email
            name = data.username

            members.append(MemberData(user_id=user_id, email=email, name=name))

        return members

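A small sketch of the MemberData mapping above; the SimpleNamespace stand-in for a Django user row is an assumption for illustration only.

from types import SimpleNamespace

from account.authentication_helper import AuthenticationHelper

helper = AuthenticationHelper()
rows = [SimpleNamespace(user_id="u1", email="a@b.com", username="alice")]
members = helper.list_of_members_from_user_model(rows)
print(members[0].name)  # -> "alice"
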
98
backend/account/authentication_plugin_registry.py
Normal file
@@ -0,0 +1,98 @@
import logging
import os
from importlib import import_module
from typing import Any

from account.constants import PluginConfig
from django.apps import apps

Logger = logging.getLogger(__name__)


def _load_plugins() -> dict[str, dict[str, Any]]:
    """Iterate through the authentication plugins and register their
    metadata."""
    auth_app = apps.get_app_config(PluginConfig.PLUGINS_APP)
    auth_package_path = auth_app.module.__package__
    auth_dir = os.path.join(auth_app.path, PluginConfig.AUTH_PLUGIN_DIR)
    auth_package_path = f"{auth_package_path}.{PluginConfig.AUTH_PLUGIN_DIR}"
    auth_modules = {}

    for item in os.listdir(auth_dir):
        # Load a plugin only if its name starts with `auth`.
        if not item.startswith(PluginConfig.AUTH_MODULE_PREFIX):
            continue
        # Load a plugin if it is a directory.
        if os.path.isdir(os.path.join(auth_dir, item)):
            auth_module_name = item
        # Load a plugin if it is a shared library.
        # The module name is extracted from the shared library name:
        # for a file named `auth.platform_architecture.so`,
        # `auth` will be the module name.
        elif item.endswith(".so"):
            auth_module_name = item.split(".")[0]
        else:
            continue
        try:
            full_module_path = f"{auth_package_path}.{auth_module_name}"
            module = import_module(full_module_path)
            metadata = getattr(module, PluginConfig.AUTH_METADATA, {})
            if metadata.get(PluginConfig.METADATA_IS_ACTIVE, False):
                auth_modules[auth_module_name] = {
                    PluginConfig.AUTH_MODULE: module,
                    PluginConfig.AUTH_METADATA: module.metadata,
                }
                Logger.info(
                    "Loaded auth plugin: %s, is_active: %s",
                    module.metadata["name"],
                    module.metadata["is_active"],
                )
            else:
                Logger.warning(
                    "Metadata is not active for %s authentication module.",
                    auth_module_name,
                )
        except ModuleNotFoundError as exception:
            Logger.error(
                "Error while importing authentication module: %s",
                exception,
            )

    if len(auth_modules) > 1:
        raise ValueError(
            "Multiple authentication modules found. "
            "Only one authentication method is allowed."
        )
    elif len(auth_modules) == 0:
        Logger.warning(
            "No authentication modules found. "
            "The application will start without an authentication module."
        )
    return auth_modules


class AuthenticationPluginRegistry:
    auth_modules: dict[str, dict[str, Any]] = _load_plugins()

    @classmethod
    def is_plugin_available(cls) -> bool:
        """Check if any authentication plugin is available.

        Returns:
            bool: True if a plugin is available, False otherwise.
        """
        return len(cls.auth_modules) > 0

    @classmethod
    def get_plugin(cls) -> Any:
        """Get the selected authentication plugin.

        Returns:
            AuthenticationService: Selected authentication plugin instance.
        """
        chosen_auth_module = next(iter(cls.auth_modules.values()))
        chosen_metadata = chosen_auth_module[PluginConfig.AUTH_METADATA]
        service_class = chosen_metadata[PluginConfig.METADATA_SERVICE_CLASS]
        return service_class()

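To make the discovery contract above concrete, here is a minimal sketch of what a plugin package under plugins/authentication/ might look like. The package and class names are assumptions; the loader only requires the directory name to start with `auth` and the module to expose a `metadata` dict with `is_active` and `service_class` keys.

# plugins/authentication/auth_example/__init__.py (hypothetical)
from account.authentication_service import AuthenticationService


class ExampleAuthenticationService(AuthenticationService):
    """Override the stub methods that this plugin actually supports."""


metadata = {
    "name": "auth_example",
    "is_active": True,  # inactive plugins are logged and skipped
    "service_class": ExampleAuthenticationService,
}
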
326
backend/account/authentication_service.py
Normal file
@@ -0,0 +1,326 @@
import logging
import uuid
from typing import Any, Optional

from account.authentication_helper import AuthenticationHelper
from account.cache_service import CacheService
from account.constants import Common, DefaultOrg
from account.custom_exceptions import Forbidden, MethodNotImplemented
from account.dto import (
    CallbackData,
    MemberData,
    MemberInvitation,
    OrganizationData,
    ResetUserPasswordDto,
    UserInfo,
    UserRoleData,
    UserSessionInfo,
)
from account.enums import UserRole
from account.models import Organization, User
from account.organization import OrganizationService
from django.http import HttpRequest
from rest_framework.request import Request
from rest_framework.response import Response
from tenant_account.models import OrganizationMember as OrganizationMember

Logger = logging.getLogger(__name__)


class AuthenticationService:
    def __init__(self) -> None:
        self.authentication_helper = AuthenticationHelper()
        self.default_user: User = self.get_user()
        self.default_organization: Organization = self.user_organization()
        self.user_session_info = self.get_user_session_info()

    def get_current_organization(self) -> Organization:
        return self.default_organization

    def get_current_user(self) -> User:
        return self.default_user

    def get_current_user_session(self) -> UserSessionInfo:
        return self.user_session_info

    def user_login(self, request: HttpRequest) -> Any:
        raise MethodNotImplemented()

    def user_signup(self, request: HttpRequest) -> Any:
        raise MethodNotImplemented()

    def is_admin_by_role(self, role: str) -> bool:
        """Check whether the given role matches the admin role.

        Args:
            role (str): Role name to check

        Returns:
            bool: True if the role is the admin role, False otherwise
        """
        try:
            return UserRole(role.lower()) == UserRole.ADMIN
        except ValueError:
            return False

    def get_callback_data(self, request: Request) -> CallbackData:
        return CallbackData(
            user_id=DefaultOrg.MOCK_USER_ID,
            email=DefaultOrg.MOCK_USER_EMAIL,
            token="",
        )

    def user_organization(self) -> Organization:
        return Organization(
            name=DefaultOrg.ORGANIZATION_NAME,
            display_name=DefaultOrg.ORGANIZATION_NAME,
            organization_id=DefaultOrg.ORGANIZATION_NAME,
            schema_name=DefaultOrg.ORGANIZATION_NAME,
        )

    def handle_invited_user_while_callback(
        self, request: Request, user: User
    ) -> MemberData:
        member_data: MemberData = MemberData(
            user_id=self.default_user.user_id,
            organization_id=self.default_organization.organization_id,
            role=[UserRole.ADMIN.value],
        )

        return member_data

    def handle_authorization_callback(
        self, user: User, data: CallbackData, redirect_url: str = ""
    ) -> Response:
        return Response()

    def add_to_organization(
        self,
        request: Request,
        user: User,
        data: Optional[dict[str, Any]] = None,
    ) -> MemberData:
        member_data: MemberData = MemberData(
            user_id=self.default_user.user_id,
            organization_id=self.default_organization.organization_id,
        )

        return member_data

    def remove_users_from_organization(
        self,
        admin: OrganizationMember,
        organization_id: str,
        user_ids: list[str],
    ) -> bool:
        raise MethodNotImplemented()

    def user_organizations(self, request: Request) -> list[OrganizationData]:
        organization_data: OrganizationData = OrganizationData(
            id=self.default_organization.organization_id,
            display_name=self.default_organization.display_name,
            name=self.default_organization.name,
        )
        return [organization_data]

    def get_organizations_by_user_id(self, id: str) -> list[OrganizationData]:
        organization_data: OrganizationData = OrganizationData(
            id=self.default_organization.organization_id,
            display_name=self.default_organization.display_name,
            name=self.default_organization.name,
        )
        return [organization_data]

    def get_organization_role_of_user(
        self, user_id: str, organization_id: str
    ) -> list[str]:
        return [UserRole.ADMIN.value]

    def is_organization_admin(self, member: OrganizationMember) -> bool:
        """Check if the organization member has administrative privileges.

        Args:
            member (OrganizationMember): The organization member to check.

        Returns:
            bool: True if the user has administrative privileges,
                False otherwise.
        """
        try:
            return UserRole(member.role) == UserRole.ADMIN
        except ValueError:
            return False

    def check_user_organization_association(self, user_email: str) -> None:
        """Check if the user is already associated with any organization.

        Raises:
            UserAlreadyAssociatedException:
                If the user is already associated with an organization.
        """
        return None

    def get_roles(self) -> list[UserRoleData]:
        return [
            UserRoleData(name=UserRole.ADMIN.value),
            UserRoleData(name=UserRole.USER.value),
        ]

    def get_invitations(self, organization_id: str) -> list[MemberInvitation]:
        raise MethodNotImplemented()

    def delete_invitation(
        self, organization_id: str, invitation_id: str
    ) -> bool:
        raise MethodNotImplemented()

    def add_organization_user_role(
        self,
        admin: OrganizationMember,
        organization_id: str,
        user_id: str,
        role_ids: list[str],
    ) -> list[str]:
        if admin.role == UserRole.ADMIN.value:
            return role_ids
        raise Forbidden

    def remove_organization_user_role(
        self,
        admin: OrganizationMember,
        organization_id: str,
        user_id: str,
        role_ids: list[str],
    ) -> list[str]:
        if admin.role == UserRole.ADMIN.value:
            return role_ids
        raise Forbidden

    def get_organization_by_org_id(self, id: str) -> OrganizationData:
        organization_data: OrganizationData = OrganizationData(
            id=DefaultOrg.ORGANIZATION_NAME,
            display_name=DefaultOrg.ORGANIZATION_NAME,
            name=DefaultOrg.ORGANIZATION_NAME,
        )
        return organization_data

    def get_user(self) -> User:
        user = CacheService.get_user_session_info(DefaultOrg.MOCK_USER_EMAIL)
        if not user:
            try:
                user = User.objects.get(email=DefaultOrg.MOCK_USER_EMAIL)
            except User.DoesNotExist:
                user = User(
                    username=DefaultOrg.MOCK_USER,
                    user_id=DefaultOrg.MOCK_USER_ID,
                    email=DefaultOrg.MOCK_USER_EMAIL,
                )
                user.save()
        if isinstance(user, User):
            id = user.id
            user_id = user.user_id
            email = user.email
        else:
            id = user[Common.ID]
            user_id = user[Common.USER_ID]
            email = user[Common.USER_EMAIL]

        current_org = Common.PUBLIC_SCHEMA_NAME

        user_session_info: UserSessionInfo = UserSessionInfo(
            id=id,
            user_id=user_id,
            email=email,
            current_org=current_org,
        )
        CacheService.set_user_session_info(user_session_info)
        user_info = User(id=id, user_id=user_id, username=email, email=email)
        return user_info

    def get_user_info(self, request: Request) -> Optional[UserInfo]:
        user: User = request.user
        if user:
            return UserInfo(
                id=user.id,
                user_id=user.user_id,
                name=user.username,
                display_name=user.username,
                email=user.email,
            )
        else:
            user = self.get_user()
            return UserInfo(
                id=user.id,
                user_id=user.user_id,
                name=user.username,
                display_name=user.username,
                email=user.email,
            )

    def get_user_session_info(self) -> UserSessionInfo:
        user_session_info_dict = CacheService.get_user_session_info(
            self.default_user.email
        )
        if not user_session_info_dict:
            user_session_info: UserSessionInfo = UserSessionInfo(
                id=self.default_user.id,
                user_id=self.default_user.user_id,
                email=self.default_user.email,
                current_org=self.default_organization.organization_id,
            )
            CacheService.set_user_session_info(user_session_info)
        else:
            user_session_info = UserSessionInfo.from_dict(
                user_session_info_dict
            )
        return user_session_info

    def get_organization_info(self, org_id: str) -> Optional[Organization]:
        return OrganizationService.get_organization_by_org_id(org_id=org_id)

    def make_organization_and_add_member(
        self,
        user_id: str,
        user_name: str,
        organization_name: Optional[str] = None,
        display_name: Optional[str] = None,
    ) -> Optional[OrganizationData]:
        organization: OrganizationData = OrganizationData(
            id=str(uuid.uuid4()),
            display_name=DefaultOrg.MOCK_ORG,
            name=DefaultOrg.MOCK_ORG,
        )
        return organization

    def make_user_organization_name(self) -> str:
        return str(uuid.uuid4())

    def make_user_organization_display_name(self, user_name: str) -> str:
        name = f"{user_name}'s" if user_name else "Your"
        return f"{name} organization"

    def user_logout(self, request: HttpRequest) -> Response:
        raise MethodNotImplemented()

    def get_user_id_from_token(
        self, token: Optional[dict[str, Any]]
    ) -> str:
        return DefaultOrg.MOCK_USER_ID

    def get_organization_members_by_org_id(
        self, organization_id: str
    ) -> list[MemberData]:
        users: list[OrganizationMember] = OrganizationMember.objects.all()
        return self.authentication_helper.list_of_members_from_user_model(users)

    def reset_user_password(self, user: User) -> ResetUserPasswordDto:
        raise MethodNotImplemented()

    def invite_user(
        self,
        admin: OrganizationMember,
        org_id: str,
        email: str,
        role: Optional[str] = None,
    ) -> bool:
        raise MethodNotImplemented()

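AuthenticationService doubles as the no-auth default and the base class for plugins: its stubs either return mock-org data or raise MethodNotImplemented (HTTP 501). A minimal sketch of a subclass overriding one stub; the class name and login URL are illustrative assumptions.

from account.authentication_service import AuthenticationService
from django.http import HttpRequest
from django.shortcuts import redirect


class RedirectingAuthService(AuthenticationService):
    # Hypothetical override: send the user to an external login page
    # instead of raising MethodNotImplemented.
    def user_login(self, request: HttpRequest):
        return redirect("https://idp.example.com/login")
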
134
backend/account/cache_service.py
Normal file
@@ -0,0 +1,134 @@
from typing import Any, Optional, Union

import redis
from account.custom_cache import CustomCache
from account.dto import UserSessionInfo
from django.conf import settings
from django.core.cache import cache


class CacheService:
    def __init__(self) -> None:
        self.cache = redis.Redis(
            host=settings.REDIS_HOST,
            port=int(settings.REDIS_PORT),
            password=settings.REDIS_PASSWORD,
            username=settings.REDIS_USER,
        )

    def get_a_key(self, key: str) -> Optional[Any]:
        data = self.cache.get(str(key))
        if data is not None:
            return data.decode("utf-8")
        return data

    def set_a_key(self, key: str, value: Any) -> None:
        self.cache.set(
            str(key),
            value,
            int(settings.WORKFLOW_ACTION_EXPIRATION_TIME_IN_SECOND),
        )

    @staticmethod
    def set_cookie(cookie: str, token: dict[str, Any]) -> None:
        cache.set(cookie, token)

    @staticmethod
    def get_cookie(cookie: str) -> dict[str, Any]:
        data: dict[str, Any] = cache.get(cookie)
        return data

    @staticmethod
    def set_user_session_info(
        user_session_info: Union[UserSessionInfo, dict[str, Any]]
    ) -> None:
        if isinstance(user_session_info, UserSessionInfo):
            email = user_session_info.email
            user_session = user_session_info.to_dict()
        else:
            email = user_session_info["email"]
            user_session = user_session_info
        session_info_key: str = CacheService.get_user_session_info_key(email)
        cache.set(
            session_info_key,
            user_session,
            int(settings.SESSION_EXPIRATION_TIME_IN_SECOND),
        )

    @staticmethod
    def get_user_session_info(email: str) -> dict[str, Any]:
        session_info_key: str = CacheService.get_user_session_info_key(email)
        data: dict[str, Any] = cache.get(session_info_key)
        return data

    @staticmethod
    def get_user_session_info_key(email: str) -> str:
        session_info_key: str = f"session:{email}"
        return session_info_key

    @staticmethod
    def check_a_key_exist(key: str, version: Any = None) -> bool:
        data: bool = cache.has_key(key, version)
        return data

    @staticmethod
    def delete_a_key(key: str, version: Any = None) -> None:
        cache.delete(key, version)

    @staticmethod
    def set_user_organizations(user_id: str, organizations: list[str]) -> None:
        key: str = f"{user_id}|organizations"
        cache.set(key, list(organizations))

    @staticmethod
    def get_user_organizations(user_id: str) -> Any:
        key: str = f"{user_id}|organizations"
        return cache.get(key)

    @staticmethod
    def remove_user_organizations(user_id: str) -> Any:
        key: str = f"{user_id}|organizations"
        return cache.delete(key)

    @staticmethod
    def add_cookie_id_to_user(user_id: str, cookie_id: str) -> None:
        custom_cache = CustomCache()
        key: str = f"{user_id}|cookies"
        custom_cache.rpush(key, cookie_id)

    @staticmethod
    def remove_cookie_id_from_user(user_id: str, cookie_id: str) -> None:
        custom_cache = CustomCache()
        key: str = f"{user_id}|cookies"
        custom_cache.lrem(key, cookie_id)

    @staticmethod
    def remove_all_session_keys(
        user_id: Optional[str] = None,
        cookie_id: Optional[str] = None,
        key: Optional[str] = None,
    ) -> None:
        if cookie_id is not None:
            cache.delete(cookie_id)
        if user_id is not None:
            cache.delete(user_id)
            CacheService.remove_user_organizations(user_id)
        if key is not None:
            cache.delete(key)

    # @staticmethod
    # def get_cookie_ids_for_user(user_id: str) -> list[str]:
    #     custom_cache = CustomCache()
    #     key: str = f"{user_id}|cookies"
    #     cookie_ids = custom_cache.lrange(key, 0, -1)
    #     return cookie_ids


# KEY_FUNCTION for cache settings
def custom_key_function(key: str, key_prefix: str, version: int) -> str:
    if version > 1:
        return f"{key_prefix}:{version}:{key}"
    if key_prefix:
        return f"{key_prefix}:{key}"
    else:
        return key

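custom_key_function is intended for Django's CACHES[...]["KEY_FUNCTION"] setting (per the comment above); it drops the version segment for version 1 so keys like session:<email> stay readable. A quick illustration of its output:

from account.cache_service import custom_key_function

print(custom_key_function("session:a@b.com", "", 1))          # session:a@b.com
print(custom_key_function("session:a@b.com", "unstract", 1))  # unstract:session:a@b.com
print(custom_key_function("session:a@b.com", "unstract", 2))  # unstract:2:session:a@b.com
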
74
backend/account/constants.py
Normal file
@@ -0,0 +1,74 @@
class OAuthConstant:
    TOKEN_USER_INFO_FEILD = "userinfo"
    TOKEN_ORG_ID_FEILD = "org_id"
    TOKEN_EMAIL_FEILD = "email"
    TOKEN_Z_ID_FEILD = "sub"
    TOKEN_USER_NAME_FEILD = "name"
    TOKEN_PRIMARY_Z_ID_FEILD = "primary_sub"


class LoginConstant:
    INVITATION = "invitation"
    ORGANIZATION = "organization"
    ORGANIZATION_NAME = "organization_name"


class Common:
    NEXT_URL_VARIABLE = "next"
    PUBLIC_SCHEMA_NAME = "public"
    ID = "id"
    USER_ID = "user_id"
    USER_EMAIL = "email"
    USER_EMAILS = "emails"
    USER_IDS = "user_ids"
    USER_ROLE = "role"
    MAX_EMAIL_IN_REQUEST = 10


class UserModel:
    USER_ID = "user_id"
    ID = "id"


class OrganizationMemberModel:
    USER_ID = "user__user_id"
    ID = "user__id"


class Cookie:
    ORG_ID = "org_id"
    Z_CODE = "z_code"
    CSRFTOKEN = "csrftoken"


class ErrorMessage:
    ORGANIZATION_EXIST = "Organization already exists"
    DUPLICATE_API = "It appears that a duplicate call may have been made."


class DefaultOrg:
    ORGANIZATION_NAME = "mock_org"
    MOCK_ORG = "mock_org"
    MOCK_USER = "mock_user"
    MOCK_USER_ID = "mock_user_id"
    MOCK_USER_EMAIL = "email@mock.com"


class PluginConfig:
    PLUGINS_APP = "plugins"
    AUTH_MODULE_PREFIX = "auth"
    AUTH_PLUGIN_DIR = "authentication"
    AUTH_MODULE = "module"
    AUTH_METADATA = "metadata"
    METADATA_SERVICE_CLASS = "service_class"
    METADATA_IS_ACTIVE = "is_active"


class AuthoErrorCode:
    """Error code reference:
    frontend/src/components/error/GenericError/GenericError.jsx."""

    IDM = "IDM"
    UMM = "UMM"
    INF = "INF"
    USF = "USF"

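The AuthoErrorCode values are what authorization_callback above packs into the redirect to settings.ERROR_URL. A minimal illustration; the URL here is a placeholder, not a value from this commit.

from urllib.parse import urlencode

from account.constants import AuthoErrorCode

error_url = "https://app.example.com/error"  # stands in for settings.ERROR_URL
print(f"{error_url}?{urlencode({'code': AuthoErrorCode.IDM})}")
# -> https://app.example.com/error?code=IDM
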
144
backend/account/custom_auth_middleware.py
Normal file
@@ -0,0 +1,144 @@
from typing import Callable, Optional

from account.authentication_plugin_registry import AuthenticationPluginRegistry
from account.authentication_service import AuthenticationService
from account.cache_service import CacheService
from account.constants import Common, Cookie, DefaultOrg
from account.dto import UserSessionInfo
from account.models import User
from account.user import UserService
from backend.constants import RequestHeader
from django.conf import settings
from django.core.cache import cache
from django.db import connection
from django.http import HttpRequest, HttpResponse, JsonResponse
from tenant_account.organization_member_service import OrganizationMemberService


class CustomAuthMiddleware:
    def __init__(self, get_response: Callable[[HttpRequest], HttpResponse]):
        self.get_response = get_response
        # One-time configuration and initialization.

    def __call__(self, request: HttpRequest) -> HttpResponse:
        user = None
        request.user = user

        # Return the response without authentication for whitelisted paths
        if any(
            request.path.startswith(path) for path in settings.WHITELISTED_PATHS
        ):
            return self.get_response(request)

        tenant_accessible_public_path = False
        if any(
            request.path.startswith(path)
            for path in settings.TENANT_ACCESSIBLE_PUBLIC_PATHS
        ):
            tenant_accessible_public_path = True

        # Authenticate with the API key
        x_api_key = request.headers.get(RequestHeader.X_API_KEY)
        if (
            settings.INTERNAL_SERVICE_API_KEY
            and x_api_key == settings.INTERNAL_SERVICE_API_KEY
        ):  # Should the API key live in settings or in env alone?
            return self.get_response(request)

        if not AuthenticationPluginRegistry.is_plugin_available():
            self.without_authentication(request, user)
        elif request.COOKIES:
            self.authenticate_with_cookies(request, tenant_accessible_public_path)
        if (
            request.user  # type: ignore
            and request.session
            and "user" in request.session
        ):
            response = self.get_response(request)  # type: ignore
            return response
        return JsonResponse({"message": "Unauthorized"}, status=401)

    def without_authentication(
        self, request: HttpRequest, user: Optional[User]
    ) -> None:
        org_id = DefaultOrg.MOCK_ORG
        user_id = DefaultOrg.MOCK_USER_ID
        email = DefaultOrg.MOCK_USER_EMAIL
        user_session_info = CacheService.get_user_session_info(email)

        if user is None:
            try:
                user_service = UserService()
                user = user_service.get_user_by_user_id(user_id)
                if not user:
                    member = user_service.get_user_by_user_id(user_id)
                    if member:
                        user = member.user
            except AttributeError:
                pass
        if user is None:
            authentication_service = AuthenticationService()
            user = authentication_service.get_current_user()

        if user:
            if not user_session_info:
                user_info: UserSessionInfo = UserSessionInfo(
                    id=user.id,
                    user_id=user.user_id,
                    email=user.email,
                    current_org=org_id,
                )
                CacheService.set_user_session_info(user_info)
                user_session_info = CacheService.get_user_session_info(email)

            request.user = user
            request.org_id = org_id
            request.session["user"] = user_session_info
            request.session.save()

    def authenticate_with_cookies(
        self,
        request: HttpRequest,
        tenant_accessible_public_path: bool,
    ) -> None:
        z_code: str = request.COOKIES.get(Cookie.Z_CODE)
        token = cache.get(z_code) if z_code else None
        if not token:
            return

        user_email = token["userinfo"]["email"]
        user_session_info = CacheService.get_user_session_info(user_email)
        if not user_session_info:
            return

        current_org = user_session_info["current_org"]
        if not current_org:
            return

        if (
            current_org != connection.get_schema()
            and not tenant_accessible_public_path
        ):
            return

        if (
            current_org == Common.PUBLIC_SCHEMA_NAME
            or tenant_accessible_public_path
        ):
            user_service = UserService()
        else:
            organization_member_service = OrganizationMemberService()
            member = organization_member_service.get_user_by_email(user_email)
            if not member:
                return
            user_service = UserService()

        user = user_service.get_user_by_email(user_email)

        if not user:
            return

        request.user = user
        request.org_id = current_org
        request.session["user"] = token
        request.session.save()

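For this middleware to run it has to be registered in Django settings. A sketch, assuming the module path used in this commit; the surrounding entries and their ordering are illustrative, not prescribed by this commit.

# settings.py (illustrative placement, after the session middleware)
MIDDLEWARE = [
    "django.middleware.security.SecurityMiddleware",
    "django.contrib.sessions.middleware.SessionMiddleware",
    # ... other middleware ...
    "account.custom_auth_middleware.CustomAuthMiddleware",
]
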
13
backend/account/custom_authentication.py
Normal file
@@ -0,0 +1,13 @@
from typing import Any

from django.http import HttpRequest
from rest_framework.exceptions import AuthenticationFailed


def api_login_required(view_func: Any) -> Any:
    def wrapper(request: HttpRequest, *args: Any, **kwargs: Any) -> Any:
        if request.user and request.session and "user" in request.session:
            return view_func(request, *args, **kwargs)
        raise AuthenticationFailed("Unauthorized")

    return wrapper

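Usage sketch for the decorator above on a plain Django view; the view itself is hypothetical.

from account.custom_authentication import api_login_required
from django.http import HttpRequest, JsonResponse


@api_login_required
def whoami(request: HttpRequest) -> JsonResponse:
    # Reached only when request.user and a "user" session entry exist;
    # otherwise DRF's AuthenticationFailed (HTTP 401) is raised.
    return JsonResponse({"email": request.user.email})
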
12
backend/account/custom_cache.py
Normal file
@@ -0,0 +1,12 @@
from django_redis import get_redis_connection


class CustomCache:
    def __init__(self) -> None:
        self.cache = get_redis_connection("default")

    def rpush(self, key: str, value: str) -> None:
        self.cache.rpush(key, value)

    def lrem(self, key: str, value: str) -> None:
        # redis-py's lrem signature is (name, count, value); count=0
        # removes every occurrence of the value from the list.
        self.cache.lrem(key, 0, value)

66
backend/account/custom_exceptions.py
Normal file
@@ -0,0 +1,66 @@
from typing import Optional

from rest_framework.exceptions import APIException


class ConflictError(Exception):
    def __init__(self, message: str) -> None:
        self.message = message
        super().__init__(self.message)


class MethodNotImplemented(APIException):
    status_code = 501
    default_detail = "Method Not Implemented"


class DuplicateData(APIException):
    status_code = 400
    default_detail = "Duplicate Data"

    def __init__(
        self, detail: Optional[str] = None, code: Optional[int] = None
    ):
        if detail is not None:
            self.detail = detail
        if code is not None:
            self.code = code
        super().__init__(detail, code)


class TableNotExistError(APIException):
    status_code = 400
    default_detail = "Unknown Table"

    def __init__(
        self, detail: Optional[str] = None, code: Optional[int] = None
    ):
        if detail is not None:
            self.detail = detail
        if code is not None:
            self.code = code
        super().__init__()


class UserNotExistError(APIException):
    status_code = 400
    default_detail = "Unknown User"

    def __init__(
        self, detail: Optional[str] = None, code: Optional[int] = None
    ):
        if detail is not None:
            self.detail = detail
        if code is not None:
            self.code = code
        super().__init__()


class Forbidden(APIException):
    status_code = 403
    default_detail = "Do not have permission to perform this action."


class UserAlreadyAssociatedException(APIException):
    status_code = 400
    default_detail = "User is already associated with one organization."

130
backend/account/dto.py
Normal file
@@ -0,0 +1,130 @@
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class MemberData:
    user_id: str
    email: Optional[str] = None
    name: Optional[str] = None
    picture: Optional[str] = None
    role: Optional[list[str]] = None
    organization_id: Optional[str] = None


@dataclass
class OrganizationData:
    id: str
    display_name: str
    name: str


@dataclass
class CallbackData:
    user_id: str
    email: str
    token: Any


@dataclass
class OrganizationSignupRequestBody:
    name: str
    display_name: str
    organization_id: str


@dataclass
class OrganizationSignupResponse:
    name: str
    display_name: str
    organization_id: str
    created_at: str


@dataclass
class UserInfo:
    email: str
    user_id: str
    id: Optional[str] = None
    name: Optional[str] = None
    display_name: Optional[str] = None
    family_name: Optional[str] = None
    picture: Optional[str] = None


@dataclass
class UserSessionInfo:
    id: str
    user_id: str
    email: str
    current_org: str

    @staticmethod
    def from_dict(data: dict[str, Any]) -> "UserSessionInfo":
        return UserSessionInfo(
            id=data["id"],
            user_id=data["user_id"],
            email=data["email"],
            current_org=data["current_org"],
        )

    def to_dict(self) -> Any:
        return {
            "id": self.id,
            "user_id": self.user_id,
            "email": self.email,
            "current_org": self.current_org,
        }


@dataclass
class GetUserResponse:
    user: UserInfo
    organizations: list[OrganizationData]


@dataclass
class ResetUserPasswordDto:
    status: bool
    message: str


@dataclass
class UserInviteResponse:
    email: str
    status: str
    message: Optional[str] = None


@dataclass
class UserRoleData:
    name: str
    id: Optional[str] = None
    description: Optional[str] = None


@dataclass
class MemberInvitation:
    """Represents an invitation to join an organization in Auth0.

    Attributes:
        id (str): The unique identifier for the invitation.
        email (str): The user's email address.
        roles (list[str]): The roles assigned to the invitee.
        created_at (Optional[str]): The timestamp when the invitation
            was created.
        expires_at (Optional[str]): The timestamp when the invitation expires.
    """

    id: str
    email: str
    roles: list[str]
    created_at: Optional[str] = None
    expires_at: Optional[str] = None


@dataclass
class UserOrganizationRole:
    user_id: str
    role: UserRoleData
    organization_id: str

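UserSessionInfo's to_dict/from_dict pair exists because CacheService stores plain dicts in the Django cache; a round-trip illustration:

from account.dto import UserSessionInfo

info = UserSessionInfo(
    id="1", user_id="u1", email="a@b.com", current_org="public"
)
restored = UserSessionInfo.from_dict(info.to_dict())
assert restored == info  # dataclass equality compares all fields
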
6
backend/account/enums.py
Normal file
@@ -0,0 +1,6 @@
from enum import Enum


class UserRole(Enum):
    USER = "user"
    ADMIN = "admin"

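UserRole backs the is_admin_by_role checks above; Enum lookup by value raises ValueError for unknown roles, which those checks translate to False:

from account.enums import UserRole

print(UserRole("admin") == UserRole.ADMIN)  # True
try:
    UserRole("owner")
except ValueError:
    print("unknown role")  # unmapped roles fall through to False
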
26
backend/account/exceptions.py
Normal file
@@ -0,0 +1,26 @@
from rest_framework.exceptions import APIException


class UserIdNotExist(APIException):
    status_code = 404
    default_detail = "User ID does not exist"


class UserAlreadyExistInOrganization(APIException):
    status_code = 403
    default_detail = "User already exists in the organization"


class OrganizationNotExist(APIException):
    status_code = 404
    default_detail = "Organization does not exist"


class UnknownException(APIException):
    status_code = 500
    default_detail = "An unexpected error occurred"


class BadRequestException(APIException):
    status_code = 400
    default_detail = "Bad Request"

237
backend/account/migrations/0001_initial.py
Normal file
@@ -0,0 +1,237 @@
# Generated by Django 4.2.1 on 2023-07-18 10:39

import django.contrib.auth.models
import django.contrib.auth.validators
import django.db.models.deletion
import django.utils.timezone
import django_tenants.postgresql_backend.base
from django.conf import settings
from django.db import migrations, models


class Migration(migrations.Migration):
    initial = True

    dependencies = [
        ("auth", "0012_alter_user_first_name_max_length"),
    ]

    operations = [
        migrations.CreateModel(
            name="User",
            fields=[
                (
                    "id",
                    models.BigAutoField(
                        auto_created=True,
                        primary_key=True,
                        serialize=False,
                        verbose_name="ID",
                    ),
                ),
                ("password", models.CharField(max_length=128, verbose_name="password")),
                (
                    "last_login",
                    models.DateTimeField(
                        blank=True, null=True, verbose_name="last login"
                    ),
                ),
                (
                    "is_superuser",
                    models.BooleanField(
                        default=False,
                        help_text="Designates that this user has all permissions without explicitly assigning them.",
                        verbose_name="superuser status",
                    ),
                ),
                (
                    "username",
                    models.CharField(
                        error_messages={
                            "unique": "A user with that username already exists."
                        },
                        help_text="Required. 150 characters or fewer. Letters, digits and @/./+/-/_ only.",
                        max_length=150,
                        unique=True,
                        validators=[
                            django.contrib.auth.validators.UnicodeUsernameValidator()
                        ],
                        verbose_name="username",
                    ),
                ),
                (
                    "first_name",
                    models.CharField(
                        blank=True, max_length=150, verbose_name="first name"
                    ),
                ),
                (
                    "last_name",
                    models.CharField(
                        blank=True, max_length=150, verbose_name="last name"
                    ),
                ),
                (
                    "email",
                    models.EmailField(
                        blank=True, max_length=254, verbose_name="email address"
                    ),
                ),
                (
                    "is_staff",
                    models.BooleanField(
                        default=False,
                        help_text="Designates whether the user can log into this admin site.",
                        verbose_name="staff status",
                    ),
                ),
                (
                    "is_active",
                    models.BooleanField(
                        default=True,
                        help_text="Designates whether this user should be treated as active. Unselect this instead of deleting accounts.",
                        verbose_name="active",
                    ),
                ),
                (
                    "date_joined",
                    models.DateTimeField(
                        default=django.utils.timezone.now, verbose_name="date joined"
                    ),
                ),
                ("user_id", models.CharField()),
                ("project_storage_created", models.BooleanField(default=False)),
                ("modified_at", models.DateTimeField(auto_now=True)),
                ("created_at", models.DateTimeField(auto_now_add=True)),
                (
                    "created_by",
                    models.ForeignKey(
                        blank=True,
                        null=True,
                        on_delete=django.db.models.deletion.SET_NULL,
                        related_name="created_users",
                        to=settings.AUTH_USER_MODEL,
                    ),
                ),
                (
                    "groups",
                    models.ManyToManyField(
                        blank=True,
                        related_name="customuser_set",
                        related_query_name="customuser",
                        to="auth.group",
                    ),
                ),
                (
                    "modified_by",
                    models.ForeignKey(
                        blank=True,
                        null=True,
                        on_delete=django.db.models.deletion.SET_NULL,
                        related_name="modified_users",
                        to=settings.AUTH_USER_MODEL,
                    ),
                ),
                (
                    "user_permissions",
                    models.ManyToManyField(
                        blank=True,
                        related_name="customuser_set",
                        related_query_name="customuser",
                        to="auth.permission",
                    ),
                ),
            ],
            options={
                "verbose_name": "user",
                "verbose_name_plural": "users",
                "abstract": False,
            },
            managers=[
                ("objects", django.contrib.auth.models.UserManager()),
            ],
        ),
        migrations.CreateModel(
            name="Organization",
            fields=[
                (
                    "id",
                    models.BigAutoField(
                        auto_created=True,
                        primary_key=True,
                        serialize=False,
                        verbose_name="ID",
                    ),
                ),
                (
                    "schema_name",
                    models.CharField(
                        db_index=True,
                        max_length=63,
                        unique=True,
                        validators=[
                            django_tenants.postgresql_backend.base._check_schema_name
                        ],
                    ),
                ),
                ("name", models.CharField(max_length=64)),
                ("display_name", models.CharField(max_length=64)),
                ("organization_id", models.CharField(max_length=64)),
                ("modified_at", models.DateTimeField(auto_now=True)),
                ("created_at", models.DateTimeField(auto_now=True)),
                (
                    "created_by",
                    models.ForeignKey(
                        blank=True,
                        null=True,
                        on_delete=django.db.models.deletion.SET_NULL,
                        related_name="created_orgs",
                        to=settings.AUTH_USER_MODEL,
                    ),
                ),
                (
                    "modified_by",
                    models.ForeignKey(
                        blank=True,
                        null=True,
                        on_delete=django.db.models.deletion.SET_NULL,
                        related_name="modified_orgs",
                        to=settings.AUTH_USER_MODEL,
                    ),
                ),
            ],
            options={
                "abstract": False,
            },
        ),
        migrations.CreateModel(
            name="Domain",
            fields=[
                (
                    "id",
                    models.BigAutoField(
                        auto_created=True,
                        primary_key=True,
                        serialize=False,
                        verbose_name="ID",
                    ),
                ),
                (
                    "domain",
                    models.CharField(db_index=True, max_length=253, unique=True),
                ),
                ("is_primary", models.BooleanField(db_index=True, default=True)),
                (
                    "tenant",
                    models.ForeignKey(
                        on_delete=django.db.models.deletion.CASCADE,
                        related_name="domains",
                        to="account.organization",
                    ),
                ),
            ],
            options={
                "abstract": False,
            },
        ),
    ]

39
backend/account/migrations/0002_auto_20230718_1040.py
Normal file
@@ -0,0 +1,39 @@
# mypy: ignore-errors
# Generated by Django 4.2.1 on 2023-07-18 10:40

from django.contrib.auth.hashers import make_password
from django.db import migrations


def create_public_tenant_and_domain(apps, schema_editor):
    Organization = apps.get_model("account", "Organization")

    # Public tenant
    tenant = Organization(
        name="public",
        display_name="public",
        organization_id="public",
        schema_name="public",
    )
    tenant.save()

    User = apps.get_model("account", "User")
    # Public admin user
    user = User(
        username="admin",
        email="admin@zipstack.com",
        is_superuser=True,
        is_staff=True,
        password=make_password("ascon"),
    )
    user.save()


class Migration(migrations.Migration):
    dependencies = [
        ("account", "0001_initial"),
    ]

    operations = [
        migrations.RunPython(create_public_tenant_and_domain),
    ]

65
backend/account/migrations/0003_platformkey.py
Normal file
@@ -0,0 +1,65 @@
# Generated by Django 4.2.1 on 2023-11-02 05:22

from django.conf import settings
from django.db import migrations, models
import django.db.models.deletion
import uuid


class Migration(migrations.Migration):
    dependencies = [
        ("account", "0002_auto_20230718_1040"),
    ]

    operations = [
        migrations.CreateModel(
            name="PlatformKey",
            fields=[
                (
                    "id",
                    models.UUIDField(
                        default=uuid.uuid4,
                        editable=False,
                        primary_key=True,
                        serialize=False,
                    ),
                ),
                ("key", models.UUIDField(default=uuid.uuid4)),
                (
                    "key_name",
                    models.CharField(blank=True, max_length=64, null=True, unique=True),
                ),
                ("is_active", models.BooleanField(default=False)),
                (
                    "created_by",
                    models.ForeignKey(
                        blank=True,
                        null=True,
                        on_delete=django.db.models.deletion.SET_NULL,
                        related_name="created_keys",
                        to=settings.AUTH_USER_MODEL,
                    ),
                ),
                (
                    "modified_by",
                    models.ForeignKey(
                        blank=True,
                        null=True,
                        on_delete=django.db.models.deletion.SET_NULL,
                        related_name="modified_keys",
                        to=settings.AUTH_USER_MODEL,
                    ),
                ),
                (
                    "organization",
                    models.ForeignKey(
                        blank=True,
                        null=True,
                        on_delete=django.db.models.deletion.SET_NULL,
                        related_name="related_org",
                        to="account.organization",
                    ),
                ),
            ],
        ),
    ]

23
backend/account/migrations/0004_alter_platformkey_key_name_and_more.py
Normal file
@@ -0,0 +1,23 @@
# Generated by Django 4.2.1 on 2023-11-15 11:37

from django.db import migrations, models


class Migration(migrations.Migration):
    dependencies = [
        ("account", "0003_platformkey"),
    ]

    operations = [
        migrations.AlterField(
            model_name="platformkey",
            name="key_name",
            field=models.CharField(blank=True, default="", max_length=64),
        ),
        migrations.AddConstraint(
            model_name="platformkey",
            constraint=models.UniqueConstraint(
                fields=("key_name", "organization"), name="unique_key_name"
            ),
        ),
    ]

39
backend/account/migrations/0005_encryptionsecret.py
Normal file
@@ -0,0 +1,39 @@
# Generated by Django 4.2.1 on 2024-02-13 11:52

from typing import Any

from account.models import EncryptionSecret
from cryptography.fernet import Fernet
from django.db import migrations, models


class Migration(migrations.Migration):
    dependencies = [
        ("account", "0004_alter_platformkey_key_name_and_more"),
    ]

    def initialize_secret(apps: Any, schema_editor: Any) -> None:
        EncryptionSecret.objects.create(
            key=Fernet.generate_key().decode("utf-8")
        )

    operations = [
        migrations.CreateModel(
            name="EncryptionSecret",
            fields=[
                (
                    "id",
                    models.BigAutoField(
                        auto_created=True,
                        primary_key=True,
                        serialize=False,
                        verbose_name="ID",
                    ),
                ),
                ("key", models.CharField(blank=True, max_length=64)),
            ],
        ),
        migrations.RunPython(
            initialize_secret, reverse_code=migrations.RunPython.noop
        ),
    ]
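For reference, the Fernet key seeded by `initialize_secret` is a 44-character URL-safe base64 string, which fits comfortably inside the `max_length=64` column. A quick sanity check, assuming only the `cryptography` package:

    from cryptography.fernet import Fernet

    key = Fernet.generate_key()              # bytes; 44 characters when decoded
    assert len(key.decode("utf-8")) == 44
    f = Fernet(key)                          # the stored key string can be reused later
    token = f.encrypt(b"hello")
    assert f.decrypt(token) == b"hello"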
0
backend/account/migrations/__init__.py
Normal file
141
backend/account/models.py
Normal file
@@ -0,0 +1,141 @@
import uuid

from backend.constants import FieldLengthConstants as FieldLength
from django.contrib.auth.models import AbstractUser, Group, Permission
from django.db import models
from django_tenants.models import DomainMixin, TenantMixin

NAME_SIZE = 64
KEY_SIZE = 64


class Organization(TenantMixin):
    """Stores data related to an organization.

    The fields created_by and modified_by are updated after a
    :model:`account.User` is created.
    """

    name = models.CharField(max_length=NAME_SIZE)
    display_name = models.CharField(max_length=NAME_SIZE)
    organization_id = models.CharField(max_length=FieldLength.ORG_NAME_SIZE)
    created_by = models.ForeignKey(
        "User",
        on_delete=models.SET_NULL,
        related_name="created_orgs",
        null=True,
        blank=True,
    )
    modified_by = models.ForeignKey(
        "User",
        on_delete=models.SET_NULL,
        related_name="modified_orgs",
        null=True,
        blank=True,
    )
    modified_at = models.DateTimeField(auto_now=True)
    # auto_now_add so the creation time is not overwritten on every save
    created_at = models.DateTimeField(auto_now_add=True)

    auto_create_schema = True


class Domain(DomainMixin):
    pass


class User(AbstractUser):
    """Stores data related to a user belonging to any organization.

    Within every org, a user is assumed to be unique.
    """

    # Third Party Authentication User ID
    user_id = models.CharField()
    project_storage_created = models.BooleanField(default=False)
    created_by = models.ForeignKey(
        "User",
        on_delete=models.SET_NULL,
        related_name="created_users",
        null=True,
        blank=True,
    )
    modified_by = models.ForeignKey(
        "User",
        on_delete=models.SET_NULL,
        related_name="modified_users",
        null=True,
        blank=True,
    )
    modified_at = models.DateTimeField(auto_now=True)
    created_at = models.DateTimeField(auto_now_add=True)

    # Specify a unique related_name for the groups field
    groups = models.ManyToManyField(
        Group,
        related_name="customuser_set",
        related_query_name="customuser",
        blank=True,
    )

    # Specify a unique related_name for the user_permissions field
    user_permissions = models.ManyToManyField(
        Permission,
        related_name="customuser_set",
        related_query_name="customuser",
        blank=True,
    )

    def __str__(self):  # type: ignore
        return f"User({self.id}, email: {self.email}, userId: {self.user_id})"


class PlatformKey(models.Model):
    """Model to hold details of Platform keys.

    Only users with the admin role are allowed to perform any operation
    related to keys.
    """

    id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False)
    key = models.UUIDField(default=uuid.uuid4)
    key_name = models.CharField(
        max_length=KEY_SIZE, null=False, blank=True, default=""
    )
    is_active = models.BooleanField(default=False)
    organization = models.ForeignKey(
        "Organization",
        on_delete=models.SET_NULL,
        related_name="related_org",
        null=True,
        blank=True,
    )
    created_by = models.ForeignKey(
        "User",
        on_delete=models.SET_NULL,
        related_name="created_keys",
        null=True,
        blank=True,
    )
    modified_by = models.ForeignKey(
        "User",
        on_delete=models.SET_NULL,
        related_name="modified_keys",
        null=True,
        blank=True,
    )

    class Meta:
        constraints = [
            models.UniqueConstraint(
                fields=["key_name", "organization"],
                name="unique_key_name",
            ),
        ]


class EncryptionSecret(models.Model):
    key = models.CharField(
        max_length=KEY_SIZE,
        null=False,
        blank=True,
    )
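Since `Organization` subclasses django-tenants' `TenantMixin` with `auto_create_schema = True`, saving one provisions a dedicated Postgres schema, and tenant-scoped models must be queried from inside that schema. A minimal sketch using `tenant_context` (the same django_tenants helper this commit imports in `backend/api/deployment_helper.py`); the lookup value is illustrative:

    from account.models import Organization
    from django_tenants.utils import tenant_context

    org = Organization.objects.get(organization_id="public")  # seeded by migration 0002
    with tenant_context(org):
        # ORM queries in this block run against the "public" schema;
        # on exit, the previously active schema is restored.
        ...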
42
backend/account/organization.py
Normal file
@@ -0,0 +1,42 @@
import logging
from typing import Optional

from account.models import Domain, Organization
from django.db import IntegrityError

Logger = logging.getLogger(__name__)


class OrganizationService:
    def __init__(self):  # type: ignore
        pass

    @staticmethod
    def get_organization_by_org_id(org_id: str) -> Optional[Organization]:
        try:
            return Organization.objects.get(organization_id=org_id)  # type: ignore
        except Organization.DoesNotExist:
            return None

    @staticmethod
    def create_organization(
        name: str, display_name: str, organization_id: str
    ) -> Organization:
        try:
            organization: Organization = Organization(
                name=name,
                display_name=display_name,
                organization_id=organization_id,
                schema_name=organization_id,
            )
            organization.save()
        except IntegrityError as error:
            Logger.info(f"[Duplicate Id] Failed to create Organization Error: {error}")
            raise error
        # Add one or more domains for the tenant
        domain = Domain()
        domain.domain = organization_id
        domain.tenant = organization
        domain.is_primary = True
        domain.save()
        return organization
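For context, a sketch of how `OrganizationService.create_organization` would be called, e.g. from a Django shell or a test; the argument values are illustrative:

    from account.organization import OrganizationService

    # Creates the Organization row, its Postgres schema (auto_create_schema),
    # and a primary Domain in one call; raises IntegrityError on duplicate IDs.
    org = OrganizationService.create_organization(
        name="acme",
        display_name="Acme Inc",
        organization_id="acme",
    )
    print(org.schema_name)  # "acme"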
86
backend/account/serializer.py
Normal file
@@ -0,0 +1,86 @@
import re

# from account.enums import Region
from account.models import Organization, User
from rest_framework import serializers


class OrganizationSignupSerializer(serializers.Serializer):
    name = serializers.CharField(required=True, max_length=150)
    display_name = serializers.CharField(required=True, max_length=150)
    organization_id = serializers.CharField(required=True, max_length=30)

    def validate_organization_id(self, value):  # type: ignore
        if not re.match(r"^[a-z0-9_-]+$", value):
            raise serializers.ValidationError(
                "organization_code should only contain alphanumeric characters, _ and -."
            )
        return value


class OrganizationCallbackSerializer(serializers.Serializer):
    id = serializers.CharField(required=False)


class GetOrganizationsResponseSerializer(serializers.Serializer):
    id = serializers.CharField()
    display_name = serializers.CharField()
    name = serializers.CharField()
    # Add more fields as needed

    def to_representation(self, instance):  # type: ignore
        data = super().to_representation(instance)
        # Modify the representation if needed
        return data


class GetOrganizationMembersResponseSerializer(serializers.Serializer):
    user_id = serializers.CharField()
    email = serializers.CharField()
    name = serializers.CharField()
    picture = serializers.CharField()
    # Add more fields as needed

    def to_representation(self, instance):  # type: ignore
        data = super().to_representation(instance)
        # Modify the representation if needed
        return data


class OrganizationSerializer(serializers.Serializer):
    name = serializers.CharField()
    organization_id = serializers.CharField()


class SetOrganizationsResponseSerializer(serializers.Serializer):
    id = serializers.CharField()
    email = serializers.CharField()
    name = serializers.CharField()
    display_name = serializers.CharField()
    family_name = serializers.CharField()
    picture = serializers.CharField()
    # Add more fields as needed

    def to_representation(self, instance):  # type: ignore
        data = super().to_representation(instance)
        # Modify the representation if needed
        return data


class ModelTenantSerializer(serializers.ModelSerializer):
    class Meta:
        model = Organization
        fields = ("name", "created_on")


class UserSerializer(serializers.ModelSerializer):
    class Meta:
        model = User
        fields = ("id", "email")


class OrganizationSignupResponseSerializer(serializers.Serializer):
    name = serializers.CharField()
    display_name = serializers.CharField()
    organization_id = serializers.CharField()
    created_at = serializers.CharField()
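The `validate_organization_id` regex above rejects anything outside lowercase alphanumerics, `_` and `-`, which matters because the organization_id doubles as the Postgres schema name. A quick illustration of what passes and what fails:

    import re

    pattern = r"^[a-z0-9_-]+$"
    for candidate in ("acme-01", "my_org", "Acme", "acme.io", ""):
        ok = bool(re.match(pattern, candidate))
        print(f"{candidate!r}: {'valid' if ok else 'rejected'}")
    # acme-01 and my_org pass; uppercase, dots and the empty string are rejected.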
11
backend/account/templates/index.html
Normal file
@@ -0,0 +1,11 @@
<!DOCTYPE html>
<html lang="en" xml:lang="en">
  <head>
    <meta charset="utf-8" />
    <title>ZipstackID Django App Example</title>
  </head>
  <body>
    <h1 id="profileDropDown">Welcome Guest</h1>
    <p><a href="{% url 'login' %}" id="qsLoginBtn">Login</a></p>
  </body>
</html>
1
backend/account/tests.py
Normal file
@@ -0,0 +1 @@
# Create your tests here.
20
backend/account/urls.py
Normal file
@@ -0,0 +1,20 @@
from account.views import (
    callback,
    create_organization,
    get_organizations,
    login,
    logout,
    set_organization,
    signup,
)
from django.urls import path

urlpatterns = [
    path("login", login, name="login"),
    path("signup", signup, name="signup"),
    path("logout", logout, name="logout"),
    path("callback", callback, name="callback"),
    path("organization", get_organizations, name="get_organizations"),
    path("organization/<str:id>/set", set_organization, name="set_organization"),
    path("organization/create", create_organization, name="create_organization"),
]
50
backend/account/user.py
Normal file
@@ -0,0 +1,50 @@
import logging
from typing import Optional

from account.models import User
from django.db import IntegrityError

Logger = logging.getLogger(__name__)


class UserService:
    def __init__(
        self,
    ) -> None:
        pass

    def create_user(self, email: str, user_id: str) -> User:
        try:
            user: User = User(email=email, user_id=user_id, username=email)
            user.save()
        except IntegrityError as error:
            Logger.info(f"[Duplicate Id] Failed to create User Error: {error}")
            raise error
        return user

    def get_user_by_email(self, email: str) -> Optional[User]:
        try:
            user: User = User.objects.get(email=email)
            return user
        except User.DoesNotExist:
            return None

    def get_user_by_user_id(self, user_id: str) -> Optional[User]:
        try:
            return User.objects.get(user_id=user_id)
        except User.DoesNotExist:
            return None

    def get_user_by_id(self, id: str) -> Optional[User]:
        """Retrieve a user by their ID, taking into account the schema context.

        Args:
            id (str): The ID of the user.

        Returns:
            Optional[User]: The user object if found, or None if not found.
        """
        try:
            return User.objects.get(id=id)
        except User.DoesNotExist:
            return None
125
backend/account/views.py
Normal file
@@ -0,0 +1,125 @@
import logging
from typing import Any

from account.authentication_controller import AuthenticationController
from account.dto import (
    OrganizationSignupRequestBody,
    OrganizationSignupResponse,
)
from account.models import Organization
from account.organization import OrganizationService
from account.serializer import (
    OrganizationSignupResponseSerializer,
    OrganizationSignupSerializer,
)
from rest_framework import status
from rest_framework.decorators import api_view
from rest_framework.request import Request
from rest_framework.response import Response

Logger = logging.getLogger(__name__)


@api_view(["POST"])
def create_organization(request: Request) -> Response:
    serializer = OrganizationSignupSerializer(data=request.data)
    serializer.is_valid(raise_exception=True)
    try:
        requestBody: OrganizationSignupRequestBody = makeSignupRequestParams(
            serializer
        )

        organization: Organization = OrganizationService.create_organization(
            requestBody.name,
            requestBody.display_name,
            requestBody.organization_id,
        )
        response = makeSignupResponse(organization)
        return Response(
            status=status.HTTP_201_CREATED,
            data={"message": "success", "tenant": response},
        )
    except Exception as error:
        Logger.error(error)
        return Response(
            status=status.HTTP_500_INTERNAL_SERVER_ERROR, data="Unknown Error"
        )


@api_view(["GET"])
def callback(request: Request) -> Response:
    auth_controller = AuthenticationController()
    return auth_controller.authorization_callback(request)


@api_view(["GET"])
def login(request: Request) -> Response:
    auth_controller = AuthenticationController()
    return auth_controller.user_login(request)


@api_view(["GET"])
def signup(request: Request) -> Response:
    auth_controller = AuthenticationController()
    return auth_controller.user_signup(request)


@api_view(["GET"])
def logout(request: Request) -> Response:
    auth_controller = AuthenticationController()
    return auth_controller.user_logout(request)


@api_view(["GET"])
def get_organizations(request: Request) -> Response:
    """get_organizations.

    Retrieve the list of organizations to which the user belongs.
    Args:
        request (Request): Incoming HTTP request

    Returns:
        Response: A list of organizations with associated information.
    """
    auth_controller = AuthenticationController()
    return auth_controller.user_organizations(request)


@api_view(["POST"])
def set_organization(request: Request, id: str) -> Response:
    """set_organization.

    Set the current organization to use.
    Args:
        request (Request): Incoming HTTP request
        id (str): Organization ID

    Returns:
        Response: Contains the User and current organization details.
    """

    auth_controller = AuthenticationController()
    return auth_controller.set_user_organization(request, id)


def makeSignupRequestParams(
    serializer: OrganizationSignupSerializer,
) -> OrganizationSignupRequestBody:
    return OrganizationSignupRequestBody(
        serializer.validated_data["name"],
        serializer.validated_data["display_name"],
        serializer.validated_data["organization_id"],
    )


def makeSignupResponse(
    organization: Organization,
) -> Any:
    return OrganizationSignupResponseSerializer(
        OrganizationSignupResponse(
            organization.name,
            organization.display_name,
            organization.organization_id,
            organization.created_at,
        )
    ).data
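Put together with `backend/account/urls.py` above, organization signup is a single POST. A sketch of the request; the host, port and any URL prefix in front of the path are assumptions, not values from this commit:

    import requests

    resp = requests.post(
        "http://localhost:8000/organization/create",  # path from account/urls.py
        json={
            "name": "acme",
            "display_name": "Acme Inc",
            "organization_id": "acme",
        },
    )
    print(resp.status_code)   # 201 on success
    print(resp.json())        # {"message": "success", "tenant": {...}}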
0
backend/adapter_processor/__init__.py
Normal file
280
backend/adapter_processor/adapter_processor.py
Normal file
@@ -0,0 +1,280 @@
import json
import logging
from typing import Any, Optional

import adapter_processor
from account.models import User
from adapter_processor.constants import AdapterKeys
from adapter_processor.exceptions import (
    InternalServiceError,
    InValidAdapterId,
    TestAdapterException,
    TestAdapterInputException,
)
from django.conf import settings
from django.core.exceptions import ObjectDoesNotExist
from platform_settings.exceptions import ActiveKeyNotFound
from platform_settings.platform_auth_service import (
    PlatformAuthenticationService,
)
from unstract.adapters.adapterkit import Adapterkit
from unstract.adapters.base import Adapter
from unstract.adapters.enums import AdapterTypes
from unstract.adapters.exceptions import AdapterError
from unstract.adapters.x2text.constants import X2TextConstants

from .models import AdapterInstance

logger = logging.getLogger(__name__)


class AdapterProcessor:
    @staticmethod
    def get_json_schema(adapter_id: str) -> dict[str, Any]:
        """Function to return JSON Schema for Adapters."""
        schema_details: dict[str, Any] = {}
        updated_adapters = AdapterProcessor.__fetch_adapters_by_key_value(
            AdapterKeys.ID, adapter_id
        )
        if len(updated_adapters) != 0:
            try:
                schema_details[AdapterKeys.JSON_SCHEMA] = json.loads(
                    updated_adapters[0].get(AdapterKeys.JSON_SCHEMA)
                )
            except Exception as exc:
                logger.error(f"Error occurred while parsing JSON Schema: {exc}")
                raise InternalServiceError()
        else:
            logger.error(
                f"Invalid adapter ID: {adapter_id} while fetching JSON Schema"
            )
            raise InValidAdapterId()
        return schema_details

    @staticmethod
    def get_all_supported_adapters(type: str) -> list[dict[Any, Any]]:
        """Function to return the list of all supported adapters."""
        supported_adapters = []
        updated_adapters = AdapterProcessor.__fetch_adapters_by_key_value(
            AdapterKeys.ADAPTER_TYPE, type
        )
        for each_adapter in updated_adapters:
            supported_adapters.append(
                {
                    AdapterKeys.ID: each_adapter.get(AdapterKeys.ID),
                    AdapterKeys.NAME: each_adapter.get(AdapterKeys.NAME),
                    AdapterKeys.DESCRIPTION: each_adapter.get(
                        AdapterKeys.DESCRIPTION
                    ),
                    AdapterKeys.ICON: each_adapter.get(AdapterKeys.ICON),
                    AdapterKeys.ADAPTER_TYPE: each_adapter.get(
                        AdapterKeys.ADAPTER_TYPE
                    ),
                }
            )
        return supported_adapters

    @staticmethod
    def get_adapter_data_with_key(adapter_id: str, key_value: str) -> Any:
        """Generic function to get adapter data for the provided key."""
        updated_adapters = AdapterProcessor.__fetch_adapters_by_key_value(
            "id", adapter_id
        )
        if len(updated_adapters) == 0:
            logger.error(
                f"Invalid adapter ID {adapter_id} while invoking utility"
            )
            raise InValidAdapterId()
        return updated_adapters[0].get(key_value)

    @staticmethod
    def test_adapter(adapter_id: str, adapter_metadata: dict[str, Any]) -> bool:
        logger.info(f"Testing adapter: {adapter_id}")
        try:
            adapter_class = Adapterkit().get_adapter_class_by_adapter_id(
                adapter_id
            )

            if (
                adapter_metadata.pop(AdapterKeys.ADAPTER_TYPE)
                == AdapterKeys.X2TEXT
            ):
                adapter_metadata[
                    X2TextConstants.X2TEXT_HOST
                ] = settings.X2TEXT_HOST
                adapter_metadata[
                    X2TextConstants.X2TEXT_PORT
                ] = settings.X2TEXT_PORT
                platform_key = (
                    PlatformAuthenticationService.get_active_platform_key()
                )
                adapter_metadata[
                    X2TextConstants.PLATFORM_SERVICE_API_KEY
                ] = str(platform_key.key)

            adapter_instance = adapter_class(adapter_metadata)
            test_result: bool = adapter_instance.test_connection()
            logger.info(f"{adapter_id} test result: {test_result}")
            return test_result
        except ActiveKeyNotFound:
            raise
        except Exception as e:
            logger.error(f"Error while testing {adapter_id}: {e}")
            if isinstance(e, AdapterError):
                raise TestAdapterInputException(str(e.message))
            else:
                raise TestAdapterException()

    @staticmethod
    def __fetch_adapters_by_key_value(
        key: str, value: Any
    ) -> list[dict[str, Any]]:
        """Fetches a list of adapters that have an attribute matching key and
        value."""
        logger.info(f"Fetching adapter list for {key} with {value}")
        adapter_kit = Adapterkit()
        adapters = adapter_kit.get_adapters_list()
        return [iterate for iterate in adapters if iterate[key] == value]

    @staticmethod
    def set_default_triad(default_triad: dict[str, str], user: User) -> None:
        filter_params: dict[str, Any] = {}

        try:
            for key in default_triad:
                filter_params.clear()
                adapter_id = default_triad[key]
                # Query rows where adapter_type=X and is_default=True
                if key == AdapterKeys.LLM_DEFAULT:
                    adapter_type = AdapterTypes.LLM.name
                elif key == AdapterKeys.EMBEDDING_DEFAULT:
                    adapter_type = AdapterTypes.EMBEDDING.name
                elif key == AdapterKeys.VECTOR_DB_DEFAULT:
                    adapter_type = AdapterTypes.VECTOR_DB.name

                filter_params["adapter_type"] = adapter_type
                filter_params["is_default"] = True
                filter_params["created_by"] = user

                AdapterInstance.objects.filter(**filter_params).update(
                    is_default=False
                )

                # Update the adapter_id in the incoming
                # list to set is_default=True
                filter_params.clear()
                try:
                    new_adapter_default: AdapterInstance = (
                        AdapterInstance.objects.get(pk=adapter_id)
                    )
                    new_adapter_default.is_default = True
                    new_adapter_default.save()
                except (
                    adapter_processor.models.AdapterInstance.DoesNotExist
                ) as e:
                    logger.error(
                        f"Error while retrieving adapter: {adapter_id} "
                        f"reason: {e}"
                    )
                    raise InValidAdapterId()
            logger.info("Changed defaults successfully")
        except Exception as e:
            logger.error(f"Unable to save defaults because: {e}")
            if isinstance(e, InValidAdapterId):
                raise e
            else:
                raise InternalServiceError()

    @staticmethod
    def get_adapter_instance_by_id(adapter_instance_id: str) -> str:
        """Get an adapter instance by its ID.

        Parameters:
        - adapter_instance_id (str): The ID of the adapter instance.

        Returns:
        - str: The name of the adapter instance with the specified ID.

        Raises:
        - Exception: If there is an error while fetching the adapter instance.
        """
        try:
            adapter = AdapterInstance.objects.get(id=adapter_instance_id)
        except Exception as e:
            logger.error(f"Unable to fetch adapter: {e}")
            raise
        return adapter.adapter_name

    @staticmethod
    def get_adapters_by_type(
        adapter_type: AdapterTypes, user: User
    ) -> list[AdapterInstance]:
        """Get a list of adapters by their type.

        Parameters:
        - adapter_type (AdapterTypes): The type of adapters to retrieve.

        Returns:
        - list[AdapterInstance]: A list of AdapterInstance objects that match
          the specified adapter type.
        """
        adapters: list[AdapterInstance] = AdapterInstance.objects.filter(
            adapter_type=adapter_type.value, created_by=user
        )
        return adapters

    @staticmethod
    def get_adapter_by_name_and_type(
        adapter_type: AdapterTypes,
        adapter_name: Optional[str] = None,
    ) -> Optional[AdapterInstance]:
        """Get the adapter instance by its name and type.

        Parameters:
        - adapter_name (str): The name of the adapter instance.
        - adapter_type (AdapterTypes): The type of the adapter instance.

        Returns:
        - AdapterInstance: The adapter with the specified name and type.
        """
        if adapter_name:
            adapter: AdapterInstance = AdapterInstance.objects.get(
                adapter_name=adapter_name, adapter_type=adapter_type.value
            )
        else:
            try:
                adapter = AdapterInstance.objects.get(
                    adapter_type=adapter_type.value, is_default=True
                )
            except AdapterInstance.DoesNotExist:
                return None
        return adapter

    @staticmethod
    def get_default_adapters(user: User) -> list[AdapterInstance]:
        """Retrieve a list of default adapter instances. This method queries
        the database to fetch all adapter instances marked as default.

        Raises:
            InternalServiceError: If an unexpected error occurs during
                the database query.

        Returns:
            list[AdapterInstance]: A list of AdapterInstance objects that are
                marked as default.
        """
        try:
            adapters: list[AdapterInstance] = AdapterInstance.objects.filter(
                is_default=True, created_by=user
            )
            return adapters
        except ObjectDoesNotExist as e:
            logger.error(f"No default adapters found: {e}")
            raise InternalServiceError("No default adapters found")
        except Exception as e:
            logger.error(f"Error occurred while fetching default adapters: {e}")
            raise InternalServiceError("Error fetching default adapters")
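For orientation, `AdapterProcessor.test_adapter` expects the metadata dict to still contain the `adapter_type` key, since it is popped before the connection test. A hedged sketch of a direct call; the adapter ID and credential values are placeholders, not values from this commit:

    from adapter_processor.adapter_processor import AdapterProcessor

    metadata = {
        "adapter_type": "LLM",          # popped inside test_adapter
        "api_key": "sk-placeholder",    # illustrative credential field
    }
    # Returns True when adapter_instance.test_connection() succeeds;
    # raises TestAdapterInputException when the adapter rejects the
    # credentials (AdapterError), TestAdapterException otherwise.
    ok = AdapterProcessor.test_adapter(
        adapter_id="some-llm-adapter-id",  # hypothetical adapter ID
        adapter_metadata=metadata,
    )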
23
backend/adapter_processor/constants.py
Normal file
@@ -0,0 +1,23 @@
class AdapterKeys:
    JSON_SCHEMA = "json_schema"
    ADAPTER_TYPE = "adapter_type"
    IS_DEFAULT = "is_default"
    LLM = "LLM"
    X2TEXT = "X2TEXT"
    VECTOR_DB = "VECTOR_DB"
    EMBEDDING = "EMBEDDING"
    NAME = "name"
    DESCRIPTION = "description"
    ICON = "icon"
    ADAPTER_ID = "adapter_id"
    ADAPTER_METADATA = "adapter_metadata"
    ADAPTER_METADATA_B = "adapter_metadata_b"
    ID = "id"
    IS_VALID = "is_valid"
    LLM_DEFAULT = "llm_default"
    VECTOR_DB_DEFAULT = "vector_db_default"
    EMBEDDING_DEFAULT = "embedding_default"
    ADAPTER_NAME_EXISTS = (
        "Configuration with this ID already exists. "
        "Please try with a different ID"
    )
55
backend/adapter_processor/exceptions.py
Normal file
@@ -0,0 +1,55 @@
from backend.exceptions import UnstractBaseException
from rest_framework.exceptions import APIException


class IdIsMandatory(APIException):
    status_code = 400
    default_detail = "ID is mandatory."


class InValidType(APIException):
    status_code = 400
    default_detail = "Type is not valid."


class InValidAdapterId(APIException):
    status_code = 400
    default_detail = "Adapter ID is not valid."


class JSONParseException(APIException):
    status_code = 500
    default_detail = "Exception occurred while parsing JSON Schema."


class InternalServiceError(APIException):
    status_code = 500
    default_detail = "Internal service error"


class CannotDeleteDefaultAdapter(APIException):
    status_code = 500
    default_detail = (
        "This is configured as default and cannot be deleted. "
        "Please configure a different default before you try again!"
    )


class UniqueConstraintViolation(APIException):
    status_code = 400
    default_detail = "Unique constraint violated"


class TestAdapterException(APIException):
    status_code = 500
    default_detail = "Error while testing adapter."


class TestAdapterInputException(UnstractBaseException):
    status_code = 400
    default_detail = "Connection test failed using the given configuration data"


class ErrorFetchingAdapterData(UnstractBaseException):
    status_code = 400
    default_detail = "Error while fetching adapter data."
109
backend/adapter_processor/migrations/0001_initial.py
Normal file
@@ -0,0 +1,109 @@
# Generated by Django 4.2.1 on 2024-01-23 11:18

import uuid

import django.db.models.deletion
from django.conf import settings
from django.db import migrations, models


class Migration(migrations.Migration):
    initial = True

    dependencies = [
        migrations.swappable_dependency(settings.AUTH_USER_MODEL),
    ]

    operations = [
        migrations.CreateModel(
            name="AdapterInstance",
            fields=[
                ("created_at", models.DateTimeField(auto_now_add=True)),
                ("modified_at", models.DateTimeField(auto_now=True)),
                (
                    "id",
                    models.UUIDField(
                        db_comment="Unique identifier for the Adapter Instance",
                        default=uuid.uuid4,
                        editable=False,
                        primary_key=True,
                        serialize=False,
                    ),
                ),
                (
                    "adapter_name",
                    models.TextField(
                        db_comment="Name of the Adapter Instance",
                        max_length=128,
                    ),
                ),
                (
                    "adapter_id",
                    models.CharField(
                        db_comment="Unique identifier of the Adapter",
                        default="",
                        max_length=128,
                    ),
                ),
                (
                    "adapter_metadata",
                    models.JSONField(
                        db_column="adapter_metadata",
                        db_comment="JSON adapter metadata submitted by the user",
                        default=dict,
                    ),
                ),
                (
                    "adapter_type",
                    models.CharField(
                        choices=[
                            ("UNKNOWN", "UNKNOWN"),
                            ("LLM", "LLM"),
                            ("EMBEDDING", "EMBEDDING"),
                            ("VECTOR_DB", "VECTOR_DB"),
                        ],
                        db_comment="Type of adapter LLM/EMBEDDING/VECTOR_DB",
                    ),
                ),
                (
                    "is_active",
                    models.BooleanField(
                        db_comment="Is the adapter instance currently being used",
                        default=False,
                    ),
                ),
                (
                    "is_default",
                    models.BooleanField(
                        db_comment="Is the adapter instance default",
                        default=False,
                    ),
                ),
                (
                    "created_by",
                    models.ForeignKey(
                        blank=True,
                        null=True,
                        on_delete=django.db.models.deletion.SET_NULL,
                        related_name="created_adapters",
                        to=settings.AUTH_USER_MODEL,
                    ),
                ),
                (
                    "modified_by",
                    models.ForeignKey(
                        blank=True,
                        null=True,
                        on_delete=django.db.models.deletion.SET_NULL,
                        related_name="modified_adapters",
                        to=settings.AUTH_USER_MODEL,
                    ),
                ),
            ],
            options={
                "verbose_name": "adapter_adapterinstance",
                "verbose_name_plural": "adapter_adapterinstance",
                "db_table": "adapter_adapterinstance",
            },
        ),
    ]
18
backend/adapter_processor/migrations/0002_adapterinstance_unique_adapter.py
Normal file
@@ -0,0 +1,18 @@
# Generated by Django 4.2.1 on 2024-01-20 08:32

from django.db import migrations, models


class Migration(migrations.Migration):
    dependencies = [
        ("adapter_processor", "0001_initial"),
    ]

    operations = [
        migrations.AddConstraint(
            model_name="adapterinstance",
            constraint=models.UniqueConstraint(
                fields=("adapter_name", "adapter_type"), name="unique_adapter"
            ),
        ),
    ]
41
backend/adapter_processor/migrations/0003_adapterinstance_adapter_metadata_b.py
Normal file
@@ -0,0 +1,41 @@
# Generated by Django 4.2.1 on 2024-02-13 13:09

import json
from typing import Any

from account.models import EncryptionSecret
from adapter_processor.models import AdapterInstance
from cryptography.fernet import Fernet
from django.db import migrations, models


class Migration(migrations.Migration):
    dependencies = [
        ("adapter_processor", "0002_adapterinstance_unique_adapter"),
        ("account", "0005_encryptionsecret"),
    ]

    def EncryptCredentials(apps: Any, schema_editor: Any) -> None:
        encryption_secret: EncryptionSecret = EncryptionSecret.objects.get()
        f: Fernet = Fernet(encryption_secret.key.encode("utf-8"))
        queryset = AdapterInstance.objects.all()

        for obj in queryset:  # type: ignore
            # Access attributes of the object
            print(f"Object ID: {obj.id}, Name: {obj.adapter_name}")
            if hasattr(obj, "adapter_metadata"):
                json_string: str = json.dumps(obj.adapter_metadata)
                obj.adapter_metadata_b = f.encrypt(json_string.encode("utf-8"))
                obj.save()

    operations = [
        migrations.AddField(
            model_name="adapterinstance",
            name="adapter_metadata_b",
            field=models.BinaryField(null=True),
        ),
        migrations.RunPython(
            EncryptCredentials, reverse_code=migrations.RunPython.noop
        ),
    ]
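The encryption scheme this migration introduces, and which the serializers below rely on, is a symmetric Fernet round-trip over the JSON metadata. A minimal sketch, assuming only the stored key string; the metadata dict is illustrative:

    import json
    from cryptography.fernet import Fernet

    key = Fernet.generate_key().decode("utf-8")   # stands in for EncryptionSecret.key
    f = Fernet(key.encode("utf-8"))

    metadata = {"api_key": "placeholder"}         # illustrative adapter metadata
    token = f.encrypt(json.dumps(metadata).encode("utf-8"))  # -> adapter_metadata_b

    # Decryption path used when serving adapter details:
    assert json.loads(f.decrypt(token)) == metadata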
@@ -0,0 +1,26 @@
# Generated by Django 4.2.1 on 2024-02-23 09:29

from django.db import migrations, models


class Migration(migrations.Migration):
    dependencies = [
        ("adapter_processor", "0003_adapterinstance_adapter_metadata_b"),
    ]

    operations = [
        migrations.AlterField(
            model_name="adapterinstance",
            name="adapter_type",
            field=models.CharField(
                choices=[
                    ("UNKNOWN", "UNKNOWN"),
                    ("LLM", "LLM"),
                    ("EMBEDDING", "EMBEDDING"),
                    ("VECTOR_DB", "VECTOR_DB"),
                    ("X2TEXT", "X2TEXT"),
                ],
                db_comment="Type of adapter LLM/EMBEDDING/VECTOR_DB",
            ),
        ),
    ]
0
backend/adapter_processor/migrations/__init__.py
Normal file
78
backend/adapter_processor/models.py
Normal file
@@ -0,0 +1,78 @@
import uuid

from account.models import User
from django.db import models
from unstract.adapters.enums import AdapterTypes
from utils.models.base_model import BaseModel

ADAPTER_NAME_SIZE = 128
VERSION_NAME_SIZE = 64
ADAPTER_ID_LENGTH = 128


class AdapterInstance(BaseModel):
    id = models.UUIDField(
        primary_key=True,
        default=uuid.uuid4,
        editable=False,
        db_comment="Unique identifier for the Adapter Instance",
    )
    adapter_name = models.TextField(
        max_length=ADAPTER_NAME_SIZE,
        null=False,
        blank=False,
        db_comment="Name of the Adapter Instance",
    )
    adapter_id = models.CharField(
        max_length=ADAPTER_ID_LENGTH,
        default="",
        db_comment="Unique identifier of the Adapter",
    )

    # TODO: to be removed once the migration to encrypted metadata is complete
    adapter_metadata = models.JSONField(
        db_column="adapter_metadata",
        null=False,
        blank=False,
        default=dict,
        db_comment="JSON adapter metadata submitted by the user",
    )
    adapter_metadata_b = models.BinaryField(null=True)
    adapter_type = models.CharField(
        choices=[(tag.value, tag.name) for tag in AdapterTypes],
        db_comment="Type of adapter LLM/EMBEDDING/VECTOR_DB",
    )
    created_by = models.ForeignKey(
        User,
        on_delete=models.SET_NULL,
        related_name="created_adapters",
        null=True,
        blank=True,
    )
    modified_by = models.ForeignKey(
        User,
        on_delete=models.SET_NULL,
        related_name="modified_adapters",
        null=True,
        blank=True,
    )

    is_active = models.BooleanField(
        default=False,
        db_comment="Is the adapter instance currently being used",
    )
    is_default = models.BooleanField(
        default=False,
        db_comment="Is the adapter instance default",
    )

    class Meta:
        verbose_name = "adapter_adapterinstance"
        verbose_name_plural = "adapter_adapterinstance"
        db_table = "adapter_adapterinstance"
        constraints = [
            models.UniqueConstraint(
                fields=["adapter_name", "adapter_type"],
                name="unique_adapter",
            ),
        ]
91
backend/adapter_processor/serializers.py
Normal file
@@ -0,0 +1,91 @@
import json
from typing import Any

from account.models import EncryptionSecret
from adapter_processor.adapter_processor import AdapterProcessor
from adapter_processor.constants import AdapterKeys
from cryptography.fernet import Fernet
from rest_framework import serializers
from unstract.adapters.constants import Common as common
from utils.serializer_utils import SerializerUtils

from backend.constants import FieldLengthConstants as FLC
from backend.serializers import AuditSerializer

from .models import AdapterInstance


class TestAdapterSerializer(serializers.Serializer):
    adapter_id = serializers.CharField(max_length=FLC.ADAPTER_ID_LENGTH)
    adapter_metadata = serializers.JSONField()
    adapter_type = serializers.JSONField()


class BaseAdapterSerializer(AuditSerializer):
    class Meta:
        model = AdapterInstance
        fields = "__all__"


class DefaultAdapterSerializer(serializers.Serializer):
    llm_default = serializers.CharField(
        max_length=FLC.UUID_LENGTH, required=False
    )
    embedding_default = serializers.CharField(
        max_length=FLC.UUID_LENGTH, required=False
    )
    vector_db_default = serializers.CharField(
        max_length=FLC.UUID_LENGTH, required=False
    )


class AdapterInstanceSerializer(BaseAdapterSerializer):
    """Inherits BaseAdapterSerializer.

    Used for GET/POST requests for adapter
    """

    class Meta(BaseAdapterSerializer.Meta):
        pass

    def to_internal_value(self, data: dict[str, Any]) -> dict[str, Any]:
        encryption_secret: EncryptionSecret = EncryptionSecret.objects.get()
        f: Fernet = Fernet(encryption_secret.key.encode("utf-8"))
        json_string: str = json.dumps(data.pop(AdapterKeys.ADAPTER_METADATA))

        data[AdapterKeys.ADAPTER_METADATA_B] = f.encrypt(
            json_string.encode("utf-8")
        )

        return data

    def to_representation(self, instance: AdapterInstance) -> dict[str, str]:
        rep: dict[str, str] = super().to_representation(instance)

        if SerializerUtils.check_context_for_GET_or_POST(context=self.context):
            rep[common.ICON] = AdapterProcessor.get_adapter_data_with_key(
                instance.adapter_id, common.ICON
            )
        rep.pop(AdapterKeys.ADAPTER_METADATA_B)
        return rep


class AdapterDetailSerializer(BaseAdapterSerializer):
    """Inherits BaseAdapterSerializer.

    Used for GET/UPDATE/DELETE requests for adapter/<uuid:pk>
    """

    def to_representation(self, instance: AdapterInstance) -> dict[str, str]:
        rep: dict[str, str] = super().to_representation(instance)

        encryption_secret: EncryptionSecret = EncryptionSecret.objects.get()
        f: Fernet = Fernet(encryption_secret.key.encode("utf-8"))

        rep.pop(AdapterKeys.ADAPTER_METADATA_B)
        adapter_metadata = json.loads(
            f.decrypt(bytes(instance.adapter_metadata_b).decode("utf-8"))
        )
        rep[AdapterKeys.ADAPTER_METADATA] = adapter_metadata

        return rep
34
backend/adapter_processor/urls.py
Normal file
@@ -0,0 +1,34 @@
from adapter_processor.views import (
    AdapterDetailViewSet,
    AdapterInstanceViewSet,
    AdapterViewSet,
    DefaultAdapterViewSet,
)
from django.urls import path
from rest_framework.urlpatterns import format_suffix_patterns

adapter = AdapterViewSet.as_view({"get": "list"})
default_triad = DefaultAdapterViewSet.as_view(
    {"post": "configure_default_triad"}
)
adapter_schema = AdapterViewSet.as_view({"get": "get_adapter_schema"})
adapter_list = AdapterInstanceViewSet.as_view({"post": "create", "get": "list"})
adapter_detail = AdapterDetailViewSet.as_view(
    {
        "get": "retrieve",
        "put": "update",
        "patch": "partial_update",
        "delete": "destroy",
    }
)
adapter_test = AdapterViewSet.as_view({"post": "test"})
urlpatterns = format_suffix_patterns(
    [
        path("adapter_schema/", adapter_schema, name="get_adapter_schema"),
        path("supported_adapters/", adapter, name="adapter-list"),
        path("adapter/", adapter_list, name="adapter-list"),
        path("adapter/default_triad/", default_triad, name="default_triad"),
        path("adapter/<uuid:pk>/", adapter_detail, name="adapter_detail"),
        path("test_adapters/", adapter_test, name="adapter-test"),
    ]
)
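Reading the routes together: listing supported adapters and fetching a schema are both GETs on `AdapterViewSet`. A sketch of the two calls; the host and any URL prefix in front of these paths are assumptions:

    import requests

    base = "http://localhost:8000"  # host and prefix are assumptions

    # List all supported LLM adapters (AdapterViewSet.list)
    resp = requests.get(
        f"{base}/supported_adapters/", params={"adapter_type": "LLM"}
    )

    # Fetch the JSON Schema of one adapter (AdapterViewSet.get_adapter_schema);
    # the id value is a placeholder.
    schema = requests.get(
        f"{base}/adapter_schema/", params={"id": "some-adapter-id"}
    )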
182
backend/adapter_processor/views.py
Normal file
@@ -0,0 +1,182 @@
import logging
from typing import Any, Optional

from account.models import User
from adapter_processor.adapter_processor import AdapterProcessor
from adapter_processor.constants import AdapterKeys
from adapter_processor.exceptions import (
    CannotDeleteDefaultAdapter,
    IdIsMandatory,
    InValidType,
    UniqueConstraintViolation,
)
from adapter_processor.serializers import (
    AdapterDetailSerializer,
    AdapterInstanceSerializer,
    DefaultAdapterSerializer,
    TestAdapterSerializer,
)
from django.db import IntegrityError
from django.db.models import QuerySet
from django.http.response import HttpResponse
from permissions.permission import IsOwner
from rest_framework import status
from rest_framework.request import Request
from rest_framework.response import Response
from rest_framework.versioning import URLPathVersioning
from rest_framework.viewsets import GenericViewSet, ModelViewSet
from utils.filtering import FilterHelper

from .constants import AdapterKeys as constant
from .exceptions import InternalServiceError
from .models import AdapterInstance

logger = logging.getLogger(__name__)


class DefaultAdapterViewSet(ModelViewSet):
    versioning_class = URLPathVersioning
    serializer_class = DefaultAdapterSerializer

    def configure_default_triad(
        self, request: Request, *args: tuple[Any], **kwargs: dict[str, Any]
    ) -> HttpResponse:
        serializer = self.get_serializer(data=request.data)
        serializer.is_valid(raise_exception=True)
        # Convert request data to json
        default_triad = request.data
        AdapterProcessor.set_default_triad(default_triad, request.user)
        return Response(status=status.HTTP_200_OK)


class AdapterViewSet(GenericViewSet):
    versioning_class = URLPathVersioning
    serializer_class = TestAdapterSerializer

    def list(
        self, request: Request, *args: tuple[Any], **kwargs: dict[str, Any]
    ) -> HttpResponse:
        if request.method == "GET":
            adapter_type = request.GET.get(AdapterKeys.ADAPTER_TYPE)
            if adapter_type in (
                AdapterKeys.LLM,
                AdapterKeys.EMBEDDING,
                AdapterKeys.VECTOR_DB,
                AdapterKeys.X2TEXT,
            ):
                json_schema = AdapterProcessor.get_all_supported_adapters(
                    type=adapter_type
                )
                return Response(json_schema, status=status.HTTP_200_OK)
            else:
                raise InValidType

    def get_adapter_schema(
        self, request: Request, *args: tuple[Any], **kwargs: dict[str, Any]
    ) -> HttpResponse:
        if request.method == "GET":
            adapter_name = request.GET.get(AdapterKeys.ID)
            if adapter_name is None or adapter_name == "":
                raise IdIsMandatory()
            json_schema = AdapterProcessor.get_json_schema(
                adapter_id=adapter_name
            )
            return Response(data=json_schema, status=status.HTTP_200_OK)

    def test(self, request: Request) -> Response:
        """Tests the connector against the credentials passed."""
        serializer: AdapterInstanceSerializer = self.get_serializer(
            data=request.data
        )
        serializer.is_valid(raise_exception=True)
        adapter_id = serializer.validated_data.get(AdapterKeys.ADAPTER_ID)
        adapter_metadata = serializer.validated_data.get(
            AdapterKeys.ADAPTER_METADATA
        )
        adapter_metadata[
            AdapterKeys.ADAPTER_TYPE
        ] = serializer.validated_data.get(AdapterKeys.ADAPTER_TYPE)
        test_result = AdapterProcessor.test_adapter(
            adapter_id=adapter_id, adapter_metadata=adapter_metadata
        )
        return Response(
            {AdapterKeys.IS_VALID: test_result},
            status=status.HTTP_200_OK,
        )


class AdapterInstanceViewSet(ModelViewSet):
    queryset = AdapterInstance.objects.all()

    serializer_class = AdapterInstanceSerializer

    def get_queryset(self) -> Optional[QuerySet]:
        if filter_args := FilterHelper.build_filter_args(
            self.request,
            constant.ADAPTER_TYPE,
        ):
            queryset = AdapterInstance.objects.filter(
                created_by=self.request.user, **filter_args
            )
        else:
            queryset = AdapterInstance.objects.filter(
                created_by=self.request.user
            )
        return queryset

    def create(self, request: Any) -> Response:
        serializer = self.get_serializer(data=request.data)
        serializer.is_valid(raise_exception=True)
        try:
            # Check to see if there is a default configured
            # for this adapter_type and for the current user
            existing_adapter_default = self.get_existing_defaults(
                request.data, request.user
            )
            # If there is no default, then make this one the default
            if existing_adapter_default is None:
                # Update the adapter_instance to is_default=True
                serializer.validated_data[AdapterKeys.IS_DEFAULT] = True

            serializer.save()
        except IntegrityError:
            raise UniqueConstraintViolation(
                f"{AdapterKeys.ADAPTER_NAME_EXISTS}"
            )
        except Exception as e:
            logger.error(f"Error saving adapter to DB: {e}")
            raise InternalServiceError
        headers = self.get_success_headers(serializer.data)
        return Response(
            serializer.data, status=status.HTTP_201_CREATED, headers=headers
        )

    def get_existing_defaults(
        self, adapter_config: dict[str, Any], user: User
    ) -> Optional[AdapterInstance]:
        filter_params: dict[str, Any] = {}
        adapter_type = adapter_config.get(AdapterKeys.ADAPTER_TYPE)
        filter_params["adapter_type"] = adapter_type
        filter_params["is_default"] = True
        filter_params["created_by"] = user
        existing_adapter_default: AdapterInstance = (
            AdapterInstance.objects.filter(**filter_params).first()
        )

        return existing_adapter_default


class AdapterDetailViewSet(ModelViewSet):
    queryset = AdapterInstance.objects.all()
    serializer_class = AdapterDetailSerializer
    permission_classes = [IsOwner]

    def destroy(
        self, request: Request, *args: tuple[Any], **kwargs: dict[str, Any]
    ) -> Response:
        adapter_instance: AdapterInstance = self.get_object()
        if adapter_instance.is_default:
            logger.error("Cannot delete a default adapter")
            raise CannotDeleteDefaultAdapter
        super().perform_destroy(adapter_instance)
        return Response(status=status.HTTP_204_NO_CONTENT)
0
backend/api/__init__.py
Normal file
5
backend/api/admin.py
Normal file
@@ -0,0 +1,5 @@
from django.contrib import admin

from .models import APIDeployment, APIKey

admin.site.register([APIDeployment, APIKey])
113
backend/api/api_deployment_views.py
Normal file
@@ -0,0 +1,113 @@
from typing import Any, Optional

from api.deployment_helper import DeploymentHelper
from api.exceptions import InvalidAPIRequest
from api.models import APIDeployment
from api.serializers import (
    APIDeploymentListSerializer,
    APIDeploymentSerializer,
    DeploymentResponseSerializer,
    ExecutionRequestSerializer,
)
from django.db.models import QuerySet
from permissions.permission import IsOwner
from rest_framework import serializers, status, views, viewsets
from rest_framework.decorators import action
from rest_framework.request import Request
from rest_framework.response import Response
from rest_framework.serializers import Serializer
from workflow_manager.workflow.dto import ExecutionResponse


class DeploymentExecution(views.APIView):
    def initialize_request(
        self, request: Request, *args: Any, **kwargs: Any
    ) -> Request:
        """Removes CSRF checks for this public API.

        Args:
            request (Request): Incoming HTTP request

        Returns:
            Request: The initialized request, marked as CSRF-exempt
        """
        setattr(request, "csrf_processing_done", True)
        return super().initialize_request(request, *args, **kwargs)

    @DeploymentHelper.validate_api_key
    def post(
        self, request: Request, org_name: str, api_name: str, api: APIDeployment
    ) -> Response:
        file_objs = request.FILES.getlist("files")
        serializer = ExecutionRequestSerializer(data=request.data)
        serializer.is_valid(raise_exception=True)
        timeout = serializer.get_timeout(serializer.validated_data)

        if not file_objs:
            raise InvalidAPIRequest("File shouldn't be empty")
        response = DeploymentHelper.execute_workflow(
            organization_name=org_name,
            api=api,
            file_objs=file_objs,
            timeout=timeout,
        )
        return Response({"message": response}, status=status.HTTP_200_OK)

    @DeploymentHelper.validate_api_key
    def get(
        self, request: Request, org_name: str, api_name: str, api: APIDeployment
    ) -> Response:
        execution_id = request.query_params.get("execution_id")
        if not execution_id:
            raise InvalidAPIRequest("execution_id shouldn't be empty")
        response: ExecutionResponse = DeploymentHelper.get_execution_status(
            execution_id=execution_id
        )
        return Response(
            {"status": response.execution_status, "message": response.result},
            status=status.HTTP_200_OK,
        )


class APIDeploymentViewSet(viewsets.ModelViewSet):
    permission_classes = [IsOwner]

    def get_queryset(self) -> Optional[QuerySet]:
        return APIDeployment.objects.filter(created_by=self.request.user)

    def get_serializer_class(self) -> serializers.Serializer:
        if self.action in ["list"]:
            return APIDeploymentListSerializer
        return APIDeploymentSerializer

    @action(detail=True, methods=["get"])
    def fetch_one(self, request: Request, pk: Optional[str] = None) -> Response:
        """Custom action to fetch a single instance."""
        instance = self.get_object()
        serializer = self.get_serializer(instance)
        return Response(serializer.data)

    def create(
        self, request: Request, *args: tuple[Any], **kwargs: dict[str, Any]
    ) -> Response:
        serializer: Serializer = self.get_serializer(data=request.data)
        serializer.is_valid(raise_exception=True)
        self.perform_create(serializer)
        api_key = DeploymentHelper.create_api_key(serializer=serializer)
        response_serializer = DeploymentResponseSerializer(
            {"api_key": api_key.api_key, **serializer.data}
        )

        headers = self.get_success_headers(serializer.data)
        return Response(
            response_serializer.data,
            status=status.HTTP_201_CREATED,
            headers=headers,
        )


def get_error_from_serializer(error_details: dict[str, Any]) -> Optional[str]:
    error_key = next(iter(error_details))
    # Get the first error message
    error_message: str = f"{error_details[error_key][0]} : {error_key}"
    return error_message
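Combining `ApiExecution.PATH` from `backend/api/constants.py` with the views above, a deployed API is exercised with a Bearer key and a multipart file upload. A hedged sketch; the host, org/api names, key and timeout field name are placeholders or assumptions, not values from this commit:

    import requests

    url = "http://localhost:8000/deployment/api/my_org/my_api/"  # path layout assumed
    headers = {"Authorization": "Bearer 123e4567-e89b-12d3-a456-426614174001"}

    # Execute the workflow against one or more files (DeploymentExecution.post)
    with open("invoice.pdf", "rb") as fh:
        resp = requests.post(
            url,
            headers=headers,
            files={"files": fh},
            data={"timeout": 300},  # field name assumed from ExecutionRequestSerializer
        )

    # Poll the execution status (DeploymentExecution.get)
    status_resp = requests.get(
        url, headers=headers, params={"execution_id": "..."}
    )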
28
backend/api/api_key_views.py
Normal file
@@ -0,0 +1,28 @@
from api.deployment_helper import DeploymentHelper
from api.exceptions import APINotFound
from api.key_helper import KeyHelper
from api.models import APIKey
from api.serializers import APIKeyListSerializer, APIKeySerializer
from rest_framework import serializers, viewsets
from rest_framework.decorators import action
from rest_framework.request import Request
from rest_framework.response import Response


class APIKeyViewSet(viewsets.ModelViewSet):
    queryset = APIKey.objects.all()

    def get_serializer_class(self) -> serializers.Serializer:
        if self.action in ["api_keys"]:
            return APIKeyListSerializer
        return APIKeySerializer

    @action(detail=True, methods=["get"])
    def api_keys(self, request: Request, api_id: str) -> Response:
        """Custom action to fetch the API keys of an API deployment."""
        api = DeploymentHelper.get_api_by_id(api_id=api_id)
        if not api:
            raise APINotFound()
        keys = KeyHelper.list_api_keys_of_api(api_instance=api)
        serializer = self.get_serializer(keys, many=True)
        return Response(serializer.data)
5
backend/api/apps.py
Normal file
@@ -0,0 +1,5 @@
from django.apps import AppConfig


class ApiConfig(AppConfig):
    name = "api"
3
backend/api/constants.py
Normal file
@@ -0,0 +1,3 @@
class ApiExecution:
    PATH: str = "deployment/api"
    MAXIMUM_TIMEOUT_IN_SEC: int = 300  # 5 minutes
254
backend/api/deployment_helper.py
Normal file
@@ -0,0 +1,254 @@
|
||||
import logging
|
||||
import uuid
|
||||
from functools import wraps
|
||||
from typing import Any, Optional
|
||||
from urllib.parse import urlencode
|
||||
|
||||
from api.constants import ApiExecution
|
||||
from api.exceptions import (
|
||||
ApiKeyCreateException,
|
||||
APINotFound,
|
||||
Forbidden,
|
||||
InactiveAPI,
|
||||
UnauthorizedKey,
|
||||
)
|
||||
from api.key_helper import KeyHelper
|
||||
from api.models import APIDeployment, APIKey
|
||||
from api.serializers import APIExecutionResponseSerializer
|
||||
from django.core.files.uploadedfile import UploadedFile
|
||||
from django.db import connection
|
||||
from django_tenants.utils import get_tenant_model, tenant_context
|
||||
from rest_framework import status
|
||||
from rest_framework.request import Request
|
||||
from rest_framework.response import Response
|
||||
from rest_framework.serializers import Serializer
|
||||
from rest_framework.utils.serializer_helpers import ReturnDict
|
||||
from workflow_manager.endpoint.destination import DestinationConnector
|
||||
from workflow_manager.endpoint.source import SourceConnector
|
||||
from workflow_manager.workflow.dto import ExecutionResponse
|
||||
from workflow_manager.workflow.models.workflow import Workflow
|
||||
from workflow_manager.workflow.workflow_helper import WorkflowHelper
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DeploymentHelper:
|
||||
@staticmethod
|
||||
def validate_api_key(func: Any) -> Any:
|
||||
"""Decorator that validates the API key.
|
||||
|
||||
Sample header:
|
||||
Authorization: Bearer 123e4567-e89b-12d3-a456-426614174001
|
||||
Args:
|
||||
func (Any): Function to wrap for validation
|
||||
"""
|
||||
|
||||
@wraps(func)
|
||||
def wrapper(
|
||||
self: Any, request: Request, *args: Any, **kwargs: Any
|
||||
) -> Any:
|
||||
"""Wrapper to validate the inputs and key.
|
||||
|
||||
Args:
|
||||
request (Request): Request context
|
||||
|
||||
Raises:
|
||||
Forbidden: _description_
|
||||
APINotFound: _description_
|
||||
|
||||
Returns:
|
||||
Any: _description_
|
||||
"""
|
||||
try:
|
||||
authorization_header = request.headers.get("Authorization")
|
||||
api_key = None
|
||||
if authorization_header and authorization_header.startswith(
|
||||
"Bearer "
|
||||
):
|
||||
api_key = authorization_header.split(" ")[1]
|
||||
if not api_key:
|
||||
raise Forbidden("Missing api key")
|
||||
org_name = kwargs.get("org_name") or request.data.get(
|
||||
"org_name"
|
||||
)
|
||||
api_name = kwargs.get("api_name") or request.data.get(
|
||||
"api_name"
|
||||
)
|
||||
if not api_name:
|
||||
raise Forbidden("Missing api_name")
|
||||
tenant = get_tenant_model().objects.get(schema_name=org_name)
|
||||
with tenant_context(tenant):
|
||||
api_deployment = (
|
||||
DeploymentHelper.get_deployment_by_api_name(
|
||||
api_name=api_name
|
||||
)
|
||||
)
|
||||
DeploymentHelper.validate_api(
|
||||
api_deployment=api_deployment, api_key=api_key
|
||||
)
|
||||
kwargs["api"] = api_deployment
|
||||
return func(self, request, *args, **kwargs)
|
||||
except (UnauthorizedKey, InactiveAPI, APINotFound):
|
||||
raise
|
||||
except Exception as exception:
|
||||
logger.error(f"Exception: {exception}")
|
||||
return Response(
|
||||
{"error": str(exception)}, status=status.HTTP_403_FORBIDDEN
|
||||
)
|
||||
|
||||
return wrapper
|
||||
|
||||
@staticmethod
|
||||
def validate_api(
|
||||
api_deployment: Optional[APIDeployment], api_key: str
|
||||
) -> None:
|
||||
"""Validating API and API key.
|
||||
|
||||
Args:
|
||||
api_deployment (Optional[APIDeployment]): _description_
|
||||
api_key (str): _description_
|
||||
|
||||
Raises:
|
||||
APINotFound: _description_
|
||||
InactiveAPI: _description_
|
||||
"""
|
||||
if not api_deployment:
|
||||
raise APINotFound()
|
||||
if not api_deployment.is_active:
|
||||
raise InactiveAPI()
|
||||
KeyHelper.validate_api_key(api_key=api_key, api_instance=api_deployment)
|
||||
|
||||
@staticmethod
|
||||
def validate_and_get_workflow(workflow_id: str) -> Workflow:
|
||||
"""Validate that the specified workflow_id exists in the Workflow
|
||||
model."""
|
||||
return WorkflowHelper.get_workflow_by_id(workflow_id)
|
||||
|
||||
@staticmethod
|
||||
def get_api_by_id(api_id: str) -> Optional[APIDeployment]:
|
||||
try:
|
||||
api_deployment: APIDeployment = APIDeployment.objects.get(pk=api_id)
|
||||
return api_deployment
|
||||
except APIDeployment.DoesNotExist:
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def construct_complete_endpoint(api_name: str) -> str:
|
||||
"""Constructs the complete API endpoint by appending organization
|
||||
schema, endpoint path, and Django app backend URL.
|
||||
|
||||
Parameters:
|
||||
- endpoint (str): The endpoint path to be appended to the complete URL.
|
||||
|
||||
Returns:
|
||||
- str: The complete API endpoint URL.
|
||||
"""
|
||||
org_schema = connection.get_tenant().schema_name
|
||||
return f"{ApiExecution.PATH}/{org_schema}/{api_name}/"
|
||||
|
||||
@staticmethod
|
||||
def construct_status_endpoint(api_endpoint: str, execution_id: str) -> str:
|
||||
"""Construct a complete status endpoint URL by appending the
|
||||
execution_id as a query parameter.
|
||||
|
||||
Args:
|
||||
api_endpoint (str): The base API endpoint.
|
||||
execution_id (str): The execution ID to be included as
|
||||
a query parameter.
|
||||
|
||||
Returns:
|
||||
str: The complete status endpoint URL.
|
||||
"""
|
||||
query_parameters = urlencode({"execution_id": execution_id})
|
||||
complete_endpoint = f"{api_endpoint}?{query_parameters}"
|
||||
return complete_endpoint
|
||||
|
||||
@staticmethod
|
||||
def get_deployment_by_api_name(
|
||||
api_name: str,
|
||||
) -> Optional[APIDeployment]:
|
||||
"""Get and return the APIDeployment object by api_name."""
|
||||
try:
|
||||
api: APIDeployment = APIDeployment.objects.get(api_name=api_name)
|
||||
return api
|
||||
except APIDeployment.DoesNotExist:
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def create_api_key(serializer: Serializer) -> APIKey:
|
||||
"""To make API key for an API.
|
||||
|
||||
Args:
|
||||
serializer (Serializer): Request serializer
|
||||
|
||||
Raises:
|
||||
ApiKeyCreateException: Exception
|
||||
"""
|
||||
api_deployment: APIDeployment = serializer.instance
|
||||
try:
|
||||
api_key: APIKey = KeyHelper.create_api_key(api_deployment)
|
||||
return api_key
|
||||
except Exception as error:
|
||||
logger.error(f"Error while creating API key error: {str(error)}")
|
||||
api_deployment.delete()
|
||||
logger.info("Deleted the deployment instance")
|
||||
raise ApiKeyCreateException()
|
||||
|
||||
@staticmethod
|
||||
def execute_workflow(
|
||||
organization_name: str,
|
||||
api: APIDeployment,
|
||||
file_objs: list[UploadedFile],
|
||||
timeout: int,
|
||||
) -> ReturnDict:
|
||||
"""Execute workflow by api.
|
||||
|
||||
Args:
|
||||
organization_name (str): organization name
|
||||
api (APIDeployment): api model object
|
||||
file_obj (UploadedFile): input file
|
||||
|
||||
Returns:
|
||||
ReturnDict: execution status/ result
|
||||
"""
|
||||
workflow_id = api.workflow.id
|
||||
pipeline_id = api.id
|
||||
execution_id = str(uuid.uuid4())
|
||||
|
||||
hash_values_of_files = SourceConnector.add_input_file_to_api_storage(
|
||||
workflow_id=workflow_id,
|
||||
execution_id=execution_id,
|
||||
file_objs=file_objs,
|
||||
)
|
||||
try:
|
||||
result = WorkflowHelper.execute_workflow_async(
|
||||
workflow_id=workflow_id,
|
||||
pipeline_id=pipeline_id,
|
||||
hash_values_of_files=hash_values_of_files,
|
||||
timeout=timeout,
|
||||
execution_id=execution_id,
|
||||
)
|
||||
result.status_api = DeploymentHelper.construct_status_endpoint(
|
||||
api_endpoint=api.api_endpoint, execution_id=execution_id
|
||||
)
|
||||
except Exception:
|
||||
DestinationConnector.delete_api_storage_dir(
|
||||
workflow_id=workflow_id, execution_id=execution_id
|
||||
)
|
||||
raise
|
||||
return APIExecutionResponseSerializer(result).data
|
||||
|
||||
@staticmethod
|
||||
def get_execution_status(execution_id: str) -> ExecutionResponse:
|
||||
"""Current status of api execution.
|
||||
|
||||
Args:
|
||||
execution_id (str): execution id
|
||||
|
||||
Returns:
|
||||
ReturnDict: status/result of execution
|
||||
"""
|
||||
execution_response: ExecutionResponse = (
|
||||
WorkflowHelper.get_status_of_async_task(execution_id=execution_id)
|
||||
)
|
||||
return execution_response
|
||||
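A sketch of invoking a deployed API with the Bearer scheme that validate_api_key() expects. The host, key value and multipart field name are assumptions:

# Hypothetical execution call; URL composed per ApiExecution.PATH + org schema + api_name.
import requests

url = "http://localhost:8000/deployment/api/my_org/invoice_api/"
headers = {"Authorization": "Bearer 123e4567-e89b-12d3-a456-426614174001"}
files = [("files", open("invoice.pdf", "rb"))]  # field name "files" is assumed
resp = requests.post(url, headers=headers, files=files, data={"timeout": 300})
print(resp.json())  # includes "status_api" to poll when execution runs async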
94
backend/api/exceptions.py
Normal file
@@ -0,0 +1,94 @@
from typing import Optional

from rest_framework.exceptions import APIException


class MandatoryWorkflowId(APIException):
    status_code = 400
    default_detail = "Workflow ID is mandatory"


class ApiKeyCreateException(APIException):
    status_code = 500
    default_detail = "Exception while creating API key"

    def __init__(
        self, detail: Optional[str] = None, code: Optional[int] = None
    ):
        if detail is not None:
            self.detail = detail
        if code is not None:
            self.code = code
        super().__init__(detail, code)


class Forbidden(APIException):
    status_code = 403
    default_detail = (
        "User is forbidden from performing this action. Please contact admin"
    )

    def __init__(
        self, detail: Optional[str] = None, code: Optional[int] = None
    ):
        if detail is not None:
            self.detail = detail
        if code is not None:
            self.code = code
        super().__init__(detail, code)


class APINotFound(APIException):
    status_code = 404
    default_detail = "API not found"

    def __init__(
        self, detail: Optional[str] = None, code: Optional[int] = None
    ):
        if detail is not None:
            self.detail = detail
        if code is not None:
            self.code = code
        super().__init__(detail, code)


class InvalidAPIRequest(APIException):
    status_code = 400
    default_detail = "Bad request"

    def __init__(
        self, detail: Optional[str] = None, code: Optional[int] = None
    ):
        if detail is not None:
            self.detail = detail
        if code is not None:
            self.code = code
        super().__init__(detail, code)


class InactiveAPI(APIException):
    status_code = 404
    default_detail = "API not found or inactive"

    def __init__(
        self, detail: Optional[str] = None, code: Optional[int] = None
    ):
        if detail is not None:
            self.detail = detail
        if code is not None:
            self.code = code
        super().__init__(detail, code)


class UnauthorizedKey(APIException):
    status_code = 401
    default_detail = "Unauthorized"

    def __init__(
        self, detail: Optional[str] = None, code: Optional[int] = None
    ):
        if detail is not None:
            self.detail = detail
        if code is not None:
            self.code = code
        super().__init__(detail, code)
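A brief sketch of how these exceptions are meant to be raised from a view; the surrounding function and detail text are illustrative, not from this commit:

# Illustrative only: default_detail applies unless a custom detail is given.
from api.exceptions import APINotFound, InactiveAPI

def check_deployment(api_deployment):
    if api_deployment is None:
        raise APINotFound()  # 404 with default detail "API not found"
    if not api_deployment.is_active:
        # A custom detail overrides default_detail via __init__
        raise InactiveAPI(detail=f"API '{api_deployment.api_name}' is inactive")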
72
backend/api/key_helper.py
Normal file
@@ -0,0 +1,72 @@
import logging

from api.exceptions import Forbidden, UnauthorizedKey
from api.models import APIDeployment, APIKey
from api.serializers import APIKeySerializer
from workflow_manager.workflow.workflow_helper import WorkflowHelper

logger = logging.getLogger(__name__)


class KeyHelper:
    @staticmethod
    def validate_api_key(api_key: str, api_instance: APIDeployment) -> None:
        """Validate api key.

        Args:
            api_key (str): api key from request
            api_instance (APIDeployment): api deployment instance

        Raises:
            UnauthorizedKey: If the key is unknown or lacks access
            Forbidden: If the API deployment could not be found
        """
        try:
            api_key_instance: APIKey = APIKey.objects.get(api_key=api_key)
            if not KeyHelper.has_access(api_key_instance, api_instance):
                raise UnauthorizedKey()
        except APIKey.DoesNotExist:
            raise UnauthorizedKey()
        except APIDeployment.DoesNotExist:
            raise Forbidden("API not found.")

    @staticmethod
    def list_api_keys_of_api(api_instance: APIDeployment) -> list[APIKey]:
        api_keys: list[APIKey] = APIKey.objects.filter(api=api_instance).all()
        return api_keys

    @staticmethod
    def has_access(api_key: APIKey, api_instance: APIDeployment) -> bool:
        """Check if the provided API key has access to the specified API
        instance.

        Args:
            api_key (APIKey): api key associated with the api
            api_instance (APIDeployment): api model

        Returns:
            bool: True if allowed to execute, False otherwise
        """
        if not api_key.is_active:
            return False
        if isinstance(api_key.api, APIDeployment):
            return api_key.api == api_instance
        return False

    @staticmethod
    def validate_workflow_exists(workflow_id: str) -> None:
        """Validate that the specified workflow_id exists in the Workflow
        model."""
        WorkflowHelper.get_workflow_by_id(workflow_id)

    @staticmethod
    def create_api_key(deployment: APIDeployment) -> APIKey:
        """Create an APIKey entity with the data from the provided
        APIDeployment instance."""
        # Create an instance of the APIKey model
        api_key_serializer = APIKeySerializer(
            data={"api": deployment.id, "description": "Initial Access Key"},
            context={"deployment": deployment},
        )
        api_key_serializer.is_valid(raise_exception=True)
        api_key: APIKey = api_key_serializer.save()
        return api_key
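A Django shell sketch of minting an extra key for an existing deployment via the helpers above; the api_name is assumed:

# Assumes a deployment named "invoice_api" already exists.
from api.deployment_helper import DeploymentHelper
from api.key_helper import KeyHelper

api = DeploymentHelper.get_deployment_by_api_name(api_name="invoice_api")
extra_key = KeyHelper.create_api_key(api)
print(extra_key.api_key)                      # freshly generated UUID
print(KeyHelper.has_access(extra_key, api))   # True: active key, same deployment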
185
backend/api/migrations/0001_initial.py
Normal file
@@ -0,0 +1,185 @@
# Generated by Django 4.2.1 on 2024-01-23 11:18

import uuid

import django.db.models.deletion
from django.conf import settings
from django.db import migrations, models


class Migration(migrations.Migration):
    initial = True

    dependencies = [
        migrations.swappable_dependency(settings.AUTH_USER_MODEL),
        ("workflow", "0001_initial"),
    ]

    operations = [
        migrations.CreateModel(
            name="APIDeployment",
            fields=[
                ("created_at", models.DateTimeField(auto_now_add=True)),
                ("modified_at", models.DateTimeField(auto_now=True)),
                (
                    "id",
                    models.UUIDField(
                        default=uuid.uuid4,
                        editable=False,
                        primary_key=True,
                        serialize=False,
                    ),
                ),
                (
                    "display_name",
                    models.CharField(
                        db_comment="User-given display name for the API.",
                        default="default api",
                        max_length=30,
                        unique=True,
                    ),
                ),
                (
                    "description",
                    models.CharField(
                        blank=True,
                        db_comment="User-given description for the API.",
                        default="",
                        max_length=255,
                    ),
                ),
                (
                    "is_active",
                    models.BooleanField(
                        db_comment="Flag indicating whether the API is active or not.",
                        default=True,
                    ),
                ),
                (
                    "api_endpoint",
                    models.CharField(
                        db_comment="URL endpoint for the API deployment.",
                        editable=False,
                        max_length=255,
                        unique=True,
                    ),
                ),
                (
                    "api_name",
                    models.CharField(
                        db_comment="Short name for the API deployment.",
                        default="default",
                        max_length=30,
                        unique=True,
                    ),
                ),
                (
                    "created_by",
                    models.ForeignKey(
                        blank=True,
                        editable=False,
                        null=True,
                        on_delete=django.db.models.deletion.SET_NULL,
                        related_name="api_created_by",
                        to=settings.AUTH_USER_MODEL,
                    ),
                ),
                (
                    "modified_by",
                    models.ForeignKey(
                        blank=True,
                        editable=False,
                        null=True,
                        on_delete=django.db.models.deletion.SET_NULL,
                        related_name="api_modified_by",
                        to=settings.AUTH_USER_MODEL,
                    ),
                ),
                (
                    "workflow",
                    models.ForeignKey(
                        db_comment="Foreign key reference to the Workflow model.",
                        on_delete=django.db.models.deletion.CASCADE,
                        to="workflow.workflow",
                    ),
                ),
            ],
            options={
                "abstract": False,
            },
        ),
        migrations.CreateModel(
            name="APIKey",
            fields=[
                ("created_at", models.DateTimeField(auto_now_add=True)),
                ("modified_at", models.DateTimeField(auto_now=True)),
                (
                    "id",
                    models.UUIDField(
                        db_comment="Unique identifier for the API key.",
                        default=uuid.uuid4,
                        editable=False,
                        primary_key=True,
                        serialize=False,
                    ),
                ),
                (
                    "api_key",
                    models.UUIDField(
                        db_comment="Actual key UUID.",
                        default=uuid.uuid4,
                        editable=False,
                        unique=True,
                    ),
                ),
                (
                    "description",
                    models.CharField(
                        db_comment="Description of the API key.",
                        max_length=255,
                        null=True,
                    ),
                ),
                (
                    "is_active",
                    models.BooleanField(
                        db_comment="Flag indicating whether the API key is active or not.",
                        default=True,
                    ),
                ),
                (
                    "api",
                    models.ForeignKey(
                        db_comment="Foreign key reference to the APIDeployment model.",
                        on_delete=django.db.models.deletion.CASCADE,
                        to="api.apideployment",
                    ),
                ),
                (
                    "created_by",
                    models.ForeignKey(
                        blank=True,
                        editable=False,
                        null=True,
                        on_delete=django.db.models.deletion.SET_NULL,
                        related_name="api_key_created_by",
                        to=settings.AUTH_USER_MODEL,
                    ),
                ),
                (
                    "modified_by",
                    models.ForeignKey(
                        blank=True,
                        editable=False,
                        null=True,
                        on_delete=django.db.models.deletion.SET_NULL,
                        related_name="api_key_modified_by",
                        to=settings.AUTH_USER_MODEL,
                    ),
                ),
            ],
            options={
                "abstract": False,
            },
        ),
    ]
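Since the backend is multitenant via django_tenants, applying this migration goes through the tenant-aware command rather than plain migrate. A sketch using Django's call_command:

# Sketch: django_tenants ships the migrate_schemas management command.
from django.core.management import call_command

call_command("migrate_schemas", "api")  # runs api migrations across tenant schemas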
0
backend/api/migrations/__init__.py
Normal file
141
backend/api/models.py
Normal file
@@ -0,0 +1,141 @@
import uuid
from typing import Any

from account.models import User
from api.constants import ApiExecution
from django.db import connection, models
from utils.models.base_model import BaseModel
from workflow_manager.workflow.models.workflow import Workflow

API_NAME_MAX_LENGTH = 30
DESCRIPTION_MAX_LENGTH = 255
API_ENDPOINT_MAX_LENGTH = 255


class APIDeployment(BaseModel):
    id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False)
    display_name = models.CharField(
        max_length=API_NAME_MAX_LENGTH,
        unique=True,
        default="default api",
        db_comment="User-given display name for the API.",
    )
    description = models.CharField(
        max_length=DESCRIPTION_MAX_LENGTH,
        blank=True,
        default="",
        db_comment="User-given description for the API.",
    )
    workflow = models.ForeignKey(
        Workflow,
        on_delete=models.CASCADE,
        db_comment="Foreign key reference to the Workflow model.",
    )
    is_active = models.BooleanField(
        default=True,
        db_comment="Flag indicating whether the API is active or not.",
    )
    api_endpoint = models.CharField(
        max_length=API_ENDPOINT_MAX_LENGTH,
        unique=True,
        editable=False,
        db_comment="URL endpoint for the API deployment.",
    )
    api_name = models.CharField(
        max_length=API_NAME_MAX_LENGTH,
        unique=True,
        default="default",
        db_comment="Short name for the API deployment.",
    )
    created_by = models.ForeignKey(
        User,
        on_delete=models.SET_NULL,
        related_name="api_created_by",
        null=True,
        blank=True,
        editable=False,
    )
    modified_by = models.ForeignKey(
        User,
        on_delete=models.SET_NULL,
        related_name="api_modified_by",
        null=True,
        blank=True,
        editable=False,
    )

    def __str__(self) -> str:
        return f"{self.id} - {self.display_name}"

    def save(self, *args: Any, **kwargs: Any) -> None:
        """Save hook to update api_endpoint.

        Custom save hook for updating the 'api_endpoint' based on
        'api_name'. If the instance is being updated, it checks for
        changes in 'api_name' and adjusts 'api_endpoint'
        accordingly. If the instance is new, 'api_endpoint' is set
        based on 'api_name' and the current database schema.
        """
        if self.pk is not None:
            try:
                original = APIDeployment.objects.get(pk=self.pk)
                if original.api_name != self.api_name:
                    org_schema = connection.get_tenant().schema_name
                    self.api_endpoint = (
                        f"{ApiExecution.PATH}/{org_schema}/{self.api_name}/"
                    )
            except APIDeployment.DoesNotExist:
                org_schema = connection.get_tenant().schema_name

                self.api_endpoint = (
                    f"{ApiExecution.PATH}/{org_schema}/{self.api_name}/"
                )
        super().save(*args, **kwargs)


class APIKey(BaseModel):
    id = models.UUIDField(
        primary_key=True,
        editable=False,
        default=uuid.uuid4,
        db_comment="Unique identifier for the API key.",
    )
    api_key = models.UUIDField(
        default=uuid.uuid4,
        editable=False,
        unique=True,
        db_comment="Actual key UUID.",
    )
    api = models.ForeignKey(
        APIDeployment,
        on_delete=models.CASCADE,
        db_comment="Foreign key reference to the APIDeployment model.",
    )
    description = models.CharField(
        max_length=DESCRIPTION_MAX_LENGTH,
        null=True,
        db_comment="Description of the API key.",
    )
    is_active = models.BooleanField(
        default=True,
        db_comment="Flag indicating whether the API key is active or not.",
    )
    created_by = models.ForeignKey(
        User,
        on_delete=models.SET_NULL,
        related_name="api_key_created_by",
        null=True,
        blank=True,
        editable=False,
    )
    modified_by = models.ForeignKey(
        User,
        on_delete=models.SET_NULL,
        related_name="api_key_modified_by",
        null=True,
        blank=True,
        editable=False,
    )

    def __str__(self) -> str:
        return f"{self.api.api_name} - {self.id} - {self.api_key}"
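A shell sketch of the save() hook above deriving api_endpoint; the workflow lookup is an assumption:

# Assumes at least one Workflow row exists in the current tenant schema.
from api.models import APIDeployment
from workflow_manager.workflow.models.workflow import Workflow

wf = Workflow.objects.first()
api = APIDeployment(display_name="Invoices", api_name="invoice_api", workflow=wf)
api.save()
# api_endpoint combines ApiExecution.PATH, the tenant schema and api_name,
# e.g. "deployment/api/<org_schema>/invoice_api/"
print(api.api_endpoint)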
124
backend/api/serializers.py
Normal file
@@ -0,0 +1,124 @@
from collections import OrderedDict
from typing import Any, Union

from api.constants import ApiExecution
from api.models import APIDeployment, APIKey
from backend.serializers import AuditSerializer
from django.core.validators import RegexValidator
from rest_framework.serializers import (
    CharField,
    IntegerField,
    JSONField,
    ModelSerializer,
    Serializer,
    ValidationError,
)


class APIDeploymentSerializer(AuditSerializer):
    class Meta:
        model = APIDeployment
        fields = "__all__"

    def validate_api_name(self, value: str) -> str:
        api_name_validator = RegexValidator(
            regex=r"^[a-zA-Z0-9_-]+$",
            message="Only letters, numbers, hyphens and "
            "underscores are allowed.",
            code="invalid_api_name",
        )
        api_name_validator(value)
        return value


class APIKeySerializer(AuditSerializer):
    class Meta:
        model = APIKey
        fields = "__all__"

    def to_representation(self, instance: APIKey) -> OrderedDict[str, Any]:
        """Override the to_representation method to include additional
        context."""
        deployment: APIDeployment = self.context.get("deployment")
        representation: OrderedDict[str, Any] = super().to_representation(
            instance
        )
        if deployment:
            representation["api"] = deployment.id
            representation["description"] = (
                f"API Key for {deployment.display_name}"
            )
            representation["is_active"] = True

        return representation


class ExecutionRequestSerializer(Serializer):
    """Execution request serializer.

    timeout: 0 means the maximum allowed timeout, -1 means async execution
    """

    timeout = IntegerField(
        min_value=-1, max_value=ApiExecution.MAXIMUM_TIMEOUT_IN_SEC, default=-1
    )

    def validate_timeout(self, value: Any) -> int:
        if not isinstance(value, int):
            raise ValidationError("timeout must be an integer.")
        if value == 0:
            value = ApiExecution.MAXIMUM_TIMEOUT_IN_SEC
        return value

    def get_timeout(self, validated_data: dict[str, Union[int, None]]) -> int:
        value = validated_data.get("timeout", -1)
        if not isinstance(value, int):
            raise ValidationError("timeout must be an integer.")
        return value


class APIDeploymentListSerializer(ModelSerializer):
    workflow_name = CharField(source="workflow.workflow_name", read_only=True)

    class Meta:
        model = APIDeployment
        fields = [
            "id",
            "workflow",
            "workflow_name",
            "display_name",
            "description",
            "is_active",
            "api_endpoint",
            "api_name",
            "created_by",
        ]


class APIKeyListSerializer(ModelSerializer):
    class Meta:
        model = APIKey
        fields = [
            "id",
            "created_at",
            "modified_at",
            "api_key",
            "is_active",
            "description",
            "api",
        ]


class DeploymentResponseSerializer(Serializer):
    is_active = CharField()
    id = CharField()
    api_key = CharField()
    api_endpoint = CharField()
    display_name = CharField()
    description = CharField()
    api_name = CharField()


class APIExecutionResponseSerializer(Serializer):
    execution_status = CharField()
    status_api = CharField()
    error = CharField()
    result = JSONField()
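A short sketch of the timeout semantics documented on ExecutionRequestSerializer above:

# 0 is coerced to the maximum timeout during validation; -1 means async.
from api.serializers import ExecutionRequestSerializer

ser = ExecutionRequestSerializer(data={"timeout": 0})
ser.is_valid(raise_exception=True)
print(ser.get_timeout(ser.validated_data))  # 300 (ApiExecution.MAXIMUM_TIMEOUT_IN_SEC)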
1
backend/api/tests.py
Normal file
@@ -0,0 +1 @@
# Create your tests here.
52
backend/api/urls.py
Normal file
@@ -0,0 +1,52 @@
from api.api_deployment_views import APIDeploymentViewSet, DeploymentExecution
from api.api_key_views import APIKeyViewSet
from django.urls import path, re_path
from rest_framework.urlpatterns import format_suffix_patterns

deployment = APIDeploymentViewSet.as_view(
    {
        "get": APIDeploymentViewSet.list.__name__,
        "post": APIDeploymentViewSet.create.__name__,
    }
)
deployment_details = APIDeploymentViewSet.as_view(
    {
        "get": APIDeploymentViewSet.retrieve.__name__,
        "put": APIDeploymentViewSet.update.__name__,
        "patch": APIDeploymentViewSet.partial_update.__name__,
        "delete": APIDeploymentViewSet.destroy.__name__,
    }
)
execute = DeploymentExecution.as_view()

key_details = APIKeyViewSet.as_view(
    {
        "get": APIKeyViewSet.retrieve.__name__,
        "put": APIKeyViewSet.update.__name__,
        "delete": APIKeyViewSet.destroy.__name__,
    }
)
api_key = APIKeyViewSet.as_view(
    {
        "get": APIKeyViewSet.api_keys.__name__,
        "post": APIKeyViewSet.create.__name__,
    }
)

urlpatterns = format_suffix_patterns(
    [
        path("deployment/", deployment, name="api_deployment"),
        path(
            "deployment/<uuid:pk>/",
            deployment_details,
            name="api_deployment_details",
        ),
        re_path(
            r"^api/(?P<org_name>[\w-]+)/(?P<api_name>[\w-]+)/?$",
            execute,
            name="api_deployment_execution",
        ),
        path("keys/<uuid:pk>/", key_details, name="key_details"),
        path("keys/api/<str:api_id>/", api_key, name="api_key"),
    ]
)
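A sketch of resolving these routes by name; the kwargs values are assumed and the final prefix depends on where this URLconf is mounted:

# Route names come from the urlpatterns above.
from django.urls import reverse

print(reverse("api_deployment"))  # e.g. ".../deployment/" under the mount prefix
print(
    reverse(
        "api_deployment_execution",
        kwargs={"org_name": "my_org", "api_name": "invoice_api"},
    )
)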
0
backend/apps/__init__.py
Normal file
4
backend/apps/constants.py
Normal file
@@ -0,0 +1,4 @@
class AppConstants:
    """Constants for Apps."""
6
backend/apps/exceptions.py
Normal file
@@ -0,0 +1,6 @@
from rest_framework.exceptions import APIException


class FetchAppListFailed(APIException):
    status_code = 400
    default_detail = "Failed to fetch App list."
9
backend/apps/urls.py
Normal file
@@ -0,0 +1,9 @@
from django.urls import path
from apps import views
from rest_framework.urlpatterns import format_suffix_patterns

urlpatterns = format_suffix_patterns(
    [
        path("app/", views.get_app_list, name="app-list"),
    ]
)
22
backend/apps/views.py
Normal file
@@ -0,0 +1,22 @@
import logging

from apps.exceptions import FetchAppListFailed
from rest_framework import status
from rest_framework.decorators import api_view
from rest_framework.request import Request
from rest_framework.response import Response

logger = logging.getLogger(__name__)


@api_view(("GET",))
def get_app_list(request: Request) -> Response:
    """API to fetch List of Apps."""
    if request.method == "GET":
        try:
            return Response(data=[], status=status.HTTP_200_OK)
        # Refactored on 19/12/2023
        # (removed backend/apps/app_processor.py)
        except Exception as exe:
            logger.error(f"Error occurred while fetching app list {exe}")
            raise FetchAppListFailed()
3
backend/backend/__init__.py
Normal file
@@ -0,0 +1,3 @@
from .celery import app as celery_app

__all__ = ["celery_app"]
20
backend/backend/asgi.py
Normal file
@@ -0,0 +1,20 @@
"""ASGI config for backend project.

It exposes the ASGI callable as a module-level variable named ``application``.

For more information on this file, see
https://docs.djangoproject.com/en/4.2/howto/deployment/asgi/
"""

import os

from django.core.asgi import get_asgi_application
from dotenv import load_dotenv

load_dotenv()
os.environ.setdefault(
    "DJANGO_SETTINGS_MODULE",
    os.environ.get("DJANGO_SETTINGS_MODULE", "backend.settings.dev"),
)

application = get_asgi_application()
29
backend/backend/celery.py
Normal file
@@ -0,0 +1,29 @@
"""This module contains the Celery configuration for the backend
project."""

import os

from celery import Celery
from django.conf import settings

# Set the default Django settings module for the 'celery' program.
os.environ.setdefault(
    "DJANGO_SETTINGS_MODULE",
    os.environ.get("DJANGO_SETTINGS_MODULE", "backend.settings.dev"),
)

# Create a Celery instance. Default time zone is UTC.
app = Celery("backend")

# Use Redis as the message broker.
app.conf.broker_url = settings.CELERY_BROKER_URL
app.conf.result_backend = settings.CELERY_RESULT_BACKEND

# Load task modules from all registered Django app configs.
app.config_from_object("django.conf:settings", namespace="CELERY")

# Autodiscover tasks in all installed apps.
app.autodiscover_tasks()

# Use the Django-Celery-Beat scheduler.
app.conf.beat_scheduler = "django_celery_beat.schedulers:DatabaseScheduler"
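A sketch of a task this configuration would pick up; the task body is illustrative:

# Placed in any installed app's tasks.py, autodiscover_tasks() registers it.
from celery import shared_task


@shared_task
def ping() -> str:
    # Serialized as JSON per CELERY_TASK_SERIALIZER; result stored in Postgres.
    return "pong"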
30
backend/backend/constants.py
Normal file
@@ -0,0 +1,30 @@
class RequestKey:
    """Commonly used keys in requests/responses."""

    REQUEST = "request"
    PROJECT = "project"
    WORKFLOW = "workflow"
    CREATED_BY = "created_by"
    MODIFIED_BY = "modified_by"
    MODIFIED_AT = "modified_at"


class FieldLengthConstants:
    """Used to determine length of fields in a model."""

    ORG_NAME_SIZE = 64
    CRON_LENGTH = 256
    UUID_LENGTH = 36
    # Not to be confused with a connector instance
    CONNECTOR_ID_LENGTH = 128
    ADAPTER_ID_LENGTH = 128


class RequestHeader:
    """Request header constants."""

    X_API_KEY = "X-API-KEY"


class UrlPathConstants:
    PROMPT_STUDIO = "prompt-studio/"
36
backend/backend/exceptions.py
Normal file
@@ -0,0 +1,36 @@
from typing import Any, Optional

from rest_framework.exceptions import APIException
from rest_framework.response import Response
from rest_framework.views import exception_handler
from unstract.connectors.exceptions import ConnectorBaseException


class UnstractBaseException(APIException):
    default_detail = "Error occurred"

    def __init__(
        self,
        detail: Optional[str] = None,
        core_err: Optional[ConnectorBaseException] = None,
        **kwargs: Any,
    ) -> None:
        if detail is None:
            detail = self.default_detail
        if core_err and core_err.user_message:
            detail = core_err.user_message
        super().__init__(detail=detail, **kwargs)
        self._core_err = core_err


class LLMHelperError(Exception):
    pass


def custom_exception_handler(exc, context) -> Response:  # type: ignore
    response = exception_handler(exc, context)

    if response is not None:
        response.data["status_code"] = response.status_code

    return response
15
backend/backend/flowerconfig.py
Normal file
@@ -0,0 +1,15 @@
# Flower is a real-time web based monitor and administration tool
# for Celery. It’s under active development,
# but is already an essential tool.
from django.conf import settings

# Broker URL
BROKER_URL = settings.CELERY_BROKER_URL

# Flower web port
PORT = 5555

# Enable basic authentication (when required)
# basic_auth = {
#     'username': 'password'
# }
40
backend/backend/public_urls.py
Normal file
@@ -0,0 +1,40 @@
"""URL configuration for backend project.

The `urlpatterns` list routes URLs to views. For more information please see:
https://docs.djangoproject.com/en/4.2/topics/http/urls/
Examples:
Function views
    1. Add an import: from my_app import views
    2. Add a URL to urlpatterns: path('', views.home, name='home')
Class-based views
    1. Add an import: from other_app.views import Home
    2. Add a URL to urlpatterns: path('', Home.as_view(), name='home')
Including another URLconf
    1. Import the include() function: from django.urls import include, path
    2. Add a URL to urlpatterns: path('blog/', include('blog.urls'))
"""

from account.admin import admin
from django.conf import settings
from django.conf.urls import *  # noqa: F401, F403
from django.urls import include, path

path_prefix = settings.PATH_PREFIX
api_path_prefix = settings.API_DEPLOYMENT_PATH_PREFIX

urlpatterns = [
    path(f"{path_prefix}/", include("account.urls")),
    # Admin URLs
    path(f"{path_prefix}/admin/doc/", include("django.contrib.admindocs.urls")),
    path(f"{path_prefix}/admin/", admin.site.urls),
    # Connector OAuth
    path(f"{path_prefix}/", include("connector_auth.urls")),
    # Docs
    path(f"{path_prefix}/", include("docs.urls")),
    # Socket.io
    path(f"{path_prefix}/", include("log_events.urls")),
    # API deployment
    path(f"{api_path_prefix}/", include("api.urls")),
    # Feature flags
    path(f"{path_prefix}/flags/", include("feature_flag.urls")),
]
23
backend/backend/serializers.py
Normal file
@@ -0,0 +1,23 @@
from typing import Any

from backend.constants import RequestKey
from rest_framework.serializers import ModelSerializer


class AuditSerializer(ModelSerializer):
    def create(self, validated_data: dict[str, Any]) -> Any:
        if self.context.get(RequestKey.REQUEST):
            validated_data[RequestKey.CREATED_BY] = self.context.get(
                RequestKey.REQUEST
            ).user
            validated_data[RequestKey.MODIFIED_BY] = self.context.get(
                RequestKey.REQUEST
            ).user
        return super().create(validated_data)

    def update(self, instance: Any, validated_data: dict[str, Any]) -> Any:
        if self.context.get(RequestKey.REQUEST):
            validated_data[RequestKey.MODIFIED_BY] = self.context.get(
                RequestKey.REQUEST
            ).user
        return super().update(instance, validated_data)
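An illustrative view-side sketch: serializers extending AuditSerializer (such as APIDeploymentSerializer earlier in this commit) stamp created_by/modified_by when the DRF request is supplied in context. The view function itself is assumed:

# "request" is the DRF request object inside a view; this function is illustrative.
from api.serializers import APIDeploymentSerializer

def create_deployment(request):
    serializer = APIDeploymentSerializer(
        data=request.data, context={"request": request}
    )
    serializer.is_valid(raise_exception=True)
    return serializer.save()  # created_by/modified_by set to request.user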
0
backend/backend/settings/__init__.py
Normal file
458
backend/backend/settings/base.py
Normal file
@@ -0,0 +1,458 @@
"""Django settings for backend project.

Generated by 'django-admin startproject' using Django 4.2.1.

For more information on this file, see
https://docs.djangoproject.com/en/4.2/topics/settings/

For the full list of settings and their values, see
https://docs.djangoproject.com/en/4.2/ref/settings/
"""

import os
from pathlib import Path
from typing import Optional

from dotenv import find_dotenv, load_dotenv

missing_settings = []


def get_required_setting(
    setting_key: str, default: Optional[str] = None
) -> Optional[str]:
    """Get the value of an environment variable specified by the given key. Add
    missing keys to `missing_settings` so that an exception can be raised at
    the end.

    Args:
        setting_key (str): The key of the environment variable
        default (Optional[str], optional): Default value to return in case
            the env is not found. Defaults to None.

    Returns:
        Optional[str]: The value of the environment variable if found,
            otherwise the default value.
    """
    data = os.environ.get(setting_key, default)
    if not data:
        missing_settings.append(setting_key)
    return data


# Build paths inside the project like this: BASE_DIR / 'subdir'.
BASE_DIR = Path(__file__).resolve().parent.parent

LOGGING = {
    "version": 1,
    "disable_existing_loggers": False,
    "formatters": {
        "verbose": {
            "format": "[%(asctime)s] %(levelname)s %(name)s: %(message)s",
            "datefmt": "%d/%b/%Y %H:%M:%S",
        },
        "simple": {
            "format": "{levelname} {message}",
            "style": "{",
        },
    },
    "handlers": {
        "console": {
            "level": "INFO",  # Set the desired logging level here
            "class": "logging.StreamHandler",
            "formatter": "verbose",
        },
    },
    "root": {
        "handlers": ["console"],
        "level": "INFO",  # Set the desired logging level here as well
    },
}


ENV_FILE = find_dotenv()
if ENV_FILE:
    load_dotenv(ENV_FILE)

# Loading environment variables

WORKFLOW_ACTION_EXPIRATION_TIME_IN_SECOND = os.environ.get(
    "WORKFLOW_ACTION_EXPIRATION_TIME_IN_SECOND", 10800
)
WEB_APP_ORIGIN_URL = os.environ.get(
    "WEB_APP_ORIGIN_URL", "http://localhost:3000"
)

LOGIN_NEXT_URL = os.environ.get("LOGIN_NEXT_URL", "http://localhost:3000/org")
LANDING_URL = os.environ.get("LANDING_URL", "http://localhost:3000/landing")
ERROR_URL = os.environ.get("ERROR_URL", "http://localhost:3000/error")

DJANGO_APP_BACKEND_URL = os.environ.get(
    "DJANGO_APP_BACKEND_URL", "http://localhost:8000"
)
INTERNAL_SERVICE_API_KEY = os.environ.get("INTERNAL_SERVICE_API_KEY")

GOOGLE_STORAGE_ACCESS_KEY_ID = os.environ.get("GOOGLE_STORAGE_ACCESS_KEY_ID")
GOOGLE_STORAGE_SECRET_ACCESS_KEY = os.environ.get(
    "GOOGLE_STORAGE_SECRET_ACCESS_KEY"
)
UNSTRACT_FREE_STORAGE_BUCKET_NAME = os.environ.get(
    "UNSTRACT_FREE_STORAGE_BUCKET_NAME", "pandora-user-storage"
)
GOOGLE_STORAGE_BASE_URL = os.environ.get("GOOGLE_STORAGE_BASE_URL")
REDIS_USER = os.environ.get("REDIS_USER", "default")
REDIS_PASSWORD = os.environ.get("REDIS_PASSWORD", "")
REDIS_HOST = os.environ.get("REDIS_HOST", "localhost")
REDIS_PORT = os.environ.get("REDIS_PORT", "6379")
REDIS_DB = os.environ.get("REDIS_DB", "")
SESSION_EXPIRATION_TIME_IN_SECOND = os.environ.get(
    "SESSION_EXPIRATION_TIME_IN_SECOND", 3600
)

PATH_PREFIX = os.environ.get("PATH_PREFIX", "api/v1").strip("/")
API_DEPLOYMENT_PATH_PREFIX = os.environ.get(
    "API_DEPLOYMENT_PATH_PREFIX", "deployment"
).strip("/")

DB_NAME = os.environ.get("DB_NAME", "unstract_db")
DB_USER = os.environ.get("DB_USER", "unstract_dev")
DB_HOST = os.environ.get("DB_HOST", "backend-db-1")
DB_PASSWORD = os.environ.get("DB_PASSWORD", "unstract_pass")
DB_PORT = os.environ.get("DB_PORT", 5432)

DEFAULT_ORGANIZATION = "default_org"
FLIPT_BASE_URL = os.environ.get("FLIPT_BASE_URL", "http://localhost:9005")
PLATFORM_HOST = os.environ.get("PLATFORM_SERVICE_HOST", "http://localhost")
PLATFORM_PORT = os.environ.get("PLATFORM_SERVICE_PORT", 3001)
PROMPT_HOST = os.environ.get("PROMPT_HOST", "http://localhost")
PROMPT_PORT = os.environ.get("PROMPT_PORT", 3003)
PROMPT_STUDIO_FILE_PATH = os.environ.get(
    "PROMPT_STUDIO_FILE_PATH", "/app/prompt-studio-data"
)
X2TEXT_HOST = os.environ.get("X2TEXT_HOST", "http://localhost")
X2TEXT_PORT = os.environ.get("X2TEXT_PORT", 3004)
STRUCTURE_TOOL_IMAGE_URL = get_required_setting("STRUCTURE_TOOL_IMAGE_URL")
STRUCTURE_TOOL_IMAGE_NAME = get_required_setting("STRUCTURE_TOOL_IMAGE_NAME")
STRUCTURE_TOOL_IMAGE_TAG = get_required_setting("STRUCTURE_TOOL_IMAGE_TAG")
WORKFLOW_DATA_DIR = os.environ.get("WORKFLOW_DATA_DIR")
API_STORAGE_DIR = os.environ.get("API_STORAGE_DIR")
# Quick-start development settings - unsuitable for production
# See https://docs.djangoproject.com/en/4.2/howto/deployment/checklist/

# SECURITY WARNING: keep the secret key used in production secret!
SECRET_KEY = get_required_setting("DJANGO_SECRET_KEY")

# SECURITY WARNING: don't run with debug turned on in production!
DEBUG = True

ALLOWED_HOSTS = ["*"]
CSRF_TRUSTED_ORIGINS = [WEB_APP_ORIGIN_URL]
CORS_ALLOW_ALL_ORIGINS = False
SESSION_COOKIE_AGE = 86400


# Application definition
SHARED_APPS = (
    # Multitenancy
    "django_tenants",
    "corsheaders",
    # For the organization model
    "account",
    # Django apps should go below this line
    "django.contrib.admin",
    "django.contrib.auth",
    "django.contrib.contenttypes",
    "django.contrib.sessions",
    "django.contrib.messages",
    "django.contrib.staticfiles",
    "django.contrib.admindocs",
    # Third party apps should go below this line,
    "rest_framework",
    # Connector OAuth
    "connector_auth",
    "social_django",
    # Doc generator
    "drf_yasg",
    "docs",
    # Plugins
    "plugins",
    "log_events",
    "feature_flag",
    "django_celery_beat",
)

TENANT_APPS = (
    # your tenant-specific apps
    "django.contrib.admin",
    "django.contrib.auth",
    "django.contrib.contenttypes",
    "django.contrib.sessions",
    "django.contrib.messages",
    "django.contrib.staticfiles",
    "tenant_account",
    "project",
    "prompt",
    "connector",
    "adapter_processor",
    "file_management",
    "workflow_manager.endpoint",
    "workflow_manager.workflow",
    "tool_instance",
    "pipeline",
    "cron_expression_generator",
    "platform_settings",
    "api",
    "prompt_studio.prompt_profile_manager",
    "prompt_studio.prompt_studio",
    "prompt_studio.prompt_studio_core",
    "prompt_studio.prompt_studio_registry",
    "prompt_studio.prompt_studio_output_manager",
)

INSTALLED_APPS = list(SHARED_APPS) + [
    app for app in TENANT_APPS if app not in SHARED_APPS
]
DEFAULT_MODEL_BACKEND = "django.contrib.auth.backends.ModelBackend"
GOOGLE_MODEL_BACKEND = "social_core.backends.google.GoogleOAuth2"

AUTHENTICATION_BACKENDS = (
    DEFAULT_MODEL_BACKEND,
    GOOGLE_MODEL_BACKEND,
)

TENANT_MODEL = "account.Organization"
TENANT_DOMAIN_MODEL = "account.Domain"
AUTH_USER_MODEL = "account.User"
PUBLIC_ORG_ID = "public"

MIDDLEWARE = [
    "corsheaders.middleware.CorsMiddleware",
    "django_tenants.middleware.TenantSubfolderMiddleware",
    "django.middleware.security.SecurityMiddleware",
    "django.contrib.sessions.middleware.SessionMiddleware",
    "django.middleware.csrf.CsrfViewMiddleware",
    "django.middleware.common.CommonMiddleware",
    "django.contrib.auth.middleware.AuthenticationMiddleware",
    "django.contrib.messages.middleware.MessageMiddleware",
    "django.middleware.clickjacking.XFrameOptionsMiddleware",
    "account.custom_auth_middleware.CustomAuthMiddleware",
    "middleware.exception.ExceptionLoggingMiddleware",
    "social_django.middleware.SocialAuthExceptionMiddleware",
]

PUBLIC_SCHEMA_URLCONF = "backend.public_urls"
ROOT_URLCONF = "backend.urls"
TENANT_SUBFOLDER_PREFIX = f"/{PATH_PREFIX}/unstract"
SHOW_PUBLIC_IF_NO_TENANT_FOUND = True

TEMPLATES = [
    {
        "BACKEND": "django.template.backends.django.DjangoTemplates",
        "DIRS": [],
        "APP_DIRS": True,
        "OPTIONS": {
            "context_processors": [
                "django.template.context_processors.debug",
                "django.template.context_processors.request",
                "django.contrib.auth.context_processors.auth",
                "django.contrib.messages.context_processors.messages",
            ],
        },
    },
]

WSGI_APPLICATION = "backend.wsgi.application"


# Database
# https://docs.djangoproject.com/en/4.2/ref/settings/#databases

DATABASES = {
    "default": {
        "ENGINE": "django_tenants.postgresql_backend",
        "NAME": f"{DB_NAME}",
        "USER": f"{DB_USER}",
        "HOST": f"{DB_HOST}",
        "PASSWORD": f"{DB_PASSWORD}",
        "PORT": f"{DB_PORT}",
        "ATOMIC_REQUESTS": True,
    }
}

DATABASE_ROUTERS = ("django_tenants.routers.TenantSyncRouter",)

CACHES = {
    "default": {
        "BACKEND": "django_redis.cache.RedisCache",
        "LOCATION": f"redis://{REDIS_HOST}:{REDIS_PORT}",
        "OPTIONS": {
            "CLIENT_CLASS": "django_redis.client.DefaultClient",
            "SERIALIZER": "django_redis.serializers.json.JSONSerializer",
            "DB": REDIS_DB,
            "USERNAME": REDIS_USER,
            "PASSWORD": REDIS_PASSWORD,
        },
        "KEY_FUNCTION": "account.cache_service.custom_key_function",
    }
}

RQ_QUEUES = {
    "default": {"USE_REDIS_CACHE": "default"},
}

# Used for asynchronous/Queued execution
# Celery based scheduler
CELERY_BROKER_URL = f"redis://{REDIS_HOST}:{REDIS_PORT}"

# CELERY_RESULT_BACKEND = f"redis://{REDIS_HOST}:{REDIS_PORT}/1"
# Postgres as result backend
CELERY_RESULT_BACKEND = (
    f"db+postgresql://{DB_USER}:{DB_PASSWORD}@{DB_HOST}:{DB_PORT}/{DB_NAME}"
)
CELERY_ACCEPT_CONTENT = ["json"]
CELERY_TASK_SERIALIZER = "json"
CELERY_RESULT_SERIALIZER = "json"
CELERY_TIMEZONE = "UTC"
CELERY_TASK_MAX_RETRIES = 3
CELERY_TASK_RETRY_BACKOFF = 60  # Time in seconds before retrying the task

# Feature Flag
FEATURE_FLAG_SERVICE_URL = {
    "evaluate": f"{FLIPT_BASE_URL}/api/v1/flags/evaluate/"
}

SCHEDULER_KWARGS = {
    "coalesce": True,
    "misfire_grace_time": 300,
    "max_instances": 1,
    "replace_existing": True,
}

# Password validation
# https://docs.djangoproject.com/en/4.2/ref/settings/#auth-password-validators

AUTH_PASSWORD_VALIDATORS = [
    {
        "NAME": "django.contrib.auth.password_validation."
        "UserAttributeSimilarityValidator",
    },
    {
        "NAME": "django.contrib.auth.password_validation."
        "MinimumLengthValidator",
    },
    {
        "NAME": "django.contrib.auth.password_validation."
        "CommonPasswordValidator",
    },
    {
        "NAME": "django.contrib.auth.password_validation."
        "NumericPasswordValidator",
    },
]


# Internationalization
# https://docs.djangoproject.com/en/4.2/topics/i18n/

LANGUAGE_CODE = "en-us"

TIME_ZONE = "UTC"

USE_I18N = True

USE_TZ = True


# Static files (CSS, JavaScript, Images)
# https://docs.djangoproject.com/en/4.2/howto/static-files/
STATIC_URL = f"/{PATH_PREFIX}/static/"

# Default primary key field type
# https://docs.djangoproject.com/en/4.2/ref/settings/#default-auto-field

DEFAULT_AUTO_FIELD = "django.db.models.BigAutoField"

REST_FRAMEWORK = {
    "DEFAULT_PERMISSION_CLASSES": [],  # TODO: Update once auth is figured
    "TEST_REQUEST_DEFAULT_FORMAT": "json",
    "EXCEPTION_HANDLER": "middleware.exception.drf_logging_exc_handler",
}

# These paths will work without authentication
WHITELISTED_PATHS_LIST = [
    "/login",
    "/home",
    "/callback",
    "/favicon.ico",
    "/logout",
    "/signup",
]
WHITELISTED_PATHS = [f"/{PATH_PREFIX}{PATH}" for PATH in WHITELISTED_PATHS_LIST]
# Whitelists the workflow-api-deployment path
WHITELISTED_PATHS.append(f"/{API_DEPLOYMENT_PATH_PREFIX}")

# Whitelist paths under tenant paths
TENANT_ACCESSIBLE_PUBLIC_PATHS_LIST = ["/oauth", "/organization", "/doc"]
TENANT_ACCESSIBLE_PUBLIC_PATHS = [
    f"/{PATH_PREFIX}{PATH}" for PATH in TENANT_ACCESSIBLE_PUBLIC_PATHS_LIST
]

# API Doc Generator Settings
# https://drf-yasg.readthedocs.io/en/stable/settings.html
REDOC_SETTINGS = {
    "PATH_IN_MIDDLE": True,
    "REQUIRED_PROPS_FIRST": True,
}

# Social Auth Settings
SOCIAL_AUTH_LOGIN_REDIRECT_URL = (
    f"{WEB_APP_ORIGIN_URL}/oauth-status/?status=success"
)
SOCIAL_AUTH_LOGIN_ERROR_URL = f"{WEB_APP_ORIGIN_URL}/oauth-status/?status=error"
SOCIAL_AUTH_EXTRA_DATA_EXPIRATION_TIME_IN_SECOND = os.environ.get(
    "SOCIAL_AUTH_EXTRA_DATA_EXPIRATION_TIME_IN_SECOND", 3600
)
SOCIAL_AUTH_USER_MODEL = "account.User"
SOCIAL_AUTH_STORAGE = "connector_auth.models.ConnectorDjangoStorage"
SOCIAL_AUTH_JSONFIELD_ENABLED = True
SOCIAL_AUTH_URL_NAMESPACE = "social"
SOCIAL_AUTH_FIELDS_STORED_IN_SESSION = ["oauth-key", "connector-guid"]
SOCIAL_AUTH_TRAILING_SLASH = False

for key in [
    "GOOGLE_OAUTH2_KEY",
    "GOOGLE_OAUTH2_SECRET",
]:
    # Assign module-level SOCIAL_AUTH_<key> settings from the environment
    globals()[f"SOCIAL_AUTH_{key}"] = os.environ.get(key)

SOCIAL_AUTH_PIPELINE = (
    # Checks if user is authenticated
    "connector_auth.pipeline.common.check_user_exists",
    # Gets user details from provider
    "social_core.pipeline.social_auth.social_details",
    "social_core.pipeline.social_auth.social_uid",
    # Cache secrets and fields in redis
    "connector_auth.pipeline.common.cache_oauth_creds",
)

# Social Auth: Google OAuth2
# Default takes care of sign in flow which we don't need for connectors
SOCIAL_AUTH_GOOGLE_OAUTH2_IGNORE_DEFAULT_SCOPE = True
SOCIAL_AUTH_GOOGLE_OAUTH2_SCOPE = [
    "https://www.googleapis.com/auth/userinfo.profile",
    "https://www.googleapis.com/auth/drive",
]
SOCIAL_AUTH_GOOGLE_OAUTH2_AUTH_EXTRA_ARGUMENTS = {
    "access_type": "offline",
    "include_granted_scopes": "true",
    "prompt": "consent",
}
SOCIAL_AUTH_GOOGLE_OAUTH2_USE_UNIQUE_USER_ID = True


# Always keep this line at the bottom of the file.
if missing_settings:
    ERROR_MESSAGE = "The following required settings are missing.\n" + ",\n".join(
        missing_settings
    )
    raise ValueError(ERROR_MESSAGE)
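A sketch of the fail-fast behavior above: if any get_required_setting() call finds no value, importing the settings module raises at the bottom of the file. The env names are real; the values are placeholders:

# Setting these before the settings module is imported keeps missing_settings
# empty, so the final ValueError is never raised.
import os

os.environ.setdefault("DJANGO_SECRET_KEY", "<dev-secret>")
os.environ.setdefault("STRUCTURE_TOOL_IMAGE_URL", "<registry-url>")
os.environ.setdefault("STRUCTURE_TOOL_IMAGE_NAME", "<image-name>")
os.environ.setdefault("STRUCTURE_TOOL_IMAGE_TAG", "<tag>")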
27
backend/backend/settings/dev.py
Normal file
@@ -0,0 +1,27 @@
from backend.settings.base import *  # noqa: F401, F403

DEBUG = True

X_FRAME_OPTIONS = "ALLOW-FROM http://localhost:3000"

CORS_ALLOWED_ORIGINS = [
    "http://localhost:3000",
    "http://127.0.0.1:3000",
    "https://dev-3xlzwou1raoituv0.us.auth0.com",
    # Other allowed origins if needed
]

CORS_ORIGIN_WHITELIST = [
    "http://localhost:3000",
    "http://127.0.0.1:3000",
    "https://dev-3xlzwou1raoituv0.us.auth0.com",
    # Other allowed origins if needed
]

CORS_ALLOW_METHODS = ["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"]

CORS_ALLOW_HEADERS = [
    "authorization",
    "content-type",
]
3
backend/backend/settings/test_cases.py
Normal file
@@ -0,0 +1,3 @@
from backend.settings.base import *  # noqa: F401, F403

DEBUG = True
53
backend/backend/urls.py
Normal file
53
backend/backend/urls.py
Normal file
@@ -0,0 +1,53 @@
"""URL configuration for backend project.

The `urlpatterns` list routes URLs to views. For more information please see:
    https://docs.djangoproject.com/en/4.2/topics/http/urls/
Examples:
Function views
    1. Add an import: from my_app import views
    2. Add a URL to urlpatterns: path('', views.home, name='home')
Class-based views
    1. Add an import: from other_app.views import Home
    2. Add a URL to urlpatterns: path('', Home.as_view(), name='home')
Including another URLconf
    1. Import the include() function: from django.urls import include, path
    2. Add a URL to urlpatterns: path('blog/', include('blog.urls'))
"""

from backend.constants import UrlPathConstants
from django.conf.urls import *  # noqa: F401, F403
from django.urls import include, path

urlpatterns = [
    path("", include("tenant_account.urls")),
    path("", include("prompt.urls")),
    path("", include("project.urls")),
    path("", include("connector.urls")),
    path("", include("connector_processor.urls")),
    path("", include("adapter_processor.urls")),
    path("", include("file_management.urls")),
    path("", include("tool_instance.urls")),
    path("", include("cron_expression_generator.urls")),
    path("", include("pipeline.urls")),
    path("", include("apps.urls")),
    path("workflow/", include("workflow_manager.urls")),
    path("platform/", include("platform_settings.urls")),
    path("api/", include("api.urls")),
    path(
        UrlPathConstants.PROMPT_STUDIO,
        include("prompt_studio.prompt_profile_manager.urls"),
    ),
    path(
        UrlPathConstants.PROMPT_STUDIO,
        include("prompt_studio.prompt_studio.urls"),
    ),
    path("", include("prompt_studio.prompt_studio_core.urls")),
    path(
        UrlPathConstants.PROMPT_STUDIO,
        include("prompt_studio.prompt_studio_registry.urls"),
    ),
    path(
        UrlPathConstants.PROMPT_STUDIO,
        include("prompt_studio.prompt_studio_output_manager.urls"),
    ),
]
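The `UrlPathConstants.PROMPT_STUDIO` mounts at the end of `urlpatterns` are identical apart from the included module; a refactoring sketch (not part of this commit) that generates them from a list:

# Equivalent, less repetitive form of the prompt_studio mounts above.
PROMPT_STUDIO_URLCONFS = [
    "prompt_studio.prompt_profile_manager.urls",
    "prompt_studio.prompt_studio.urls",
    "prompt_studio.prompt_studio_registry.urls",
    "prompt_studio.prompt_studio_output_manager.urls",
]
urlpatterns += [
    path(UrlPathConstants.PROMPT_STUDIO, include(urlconf))
    for urlconf in PROMPT_STUDIO_URLCONFS
]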
25
backend/backend/wsgi.py
Normal file
@@ -0,0 +1,25 @@
"""WSGI config for backend project.

It exposes the WSGI callable as a module-level variable named ``application``.

For more information on this file, see
https://docs.djangoproject.com/en/4.2/howto/deployment/wsgi/
"""

import os

from dotenv import load_dotenv

load_dotenv()
# Must run before any Django settings are accessed, including transitively
# via the log_events import below.
os.environ.setdefault("DJANGO_SETTINGS_MODULE", "backend.settings.dev")

import socketio  # noqa: E402
from django.conf import settings  # noqa: E402
from django.core.wsgi import get_wsgi_application  # noqa: E402

django_app = get_wsgi_application()

# Imported after Django is fully set up, since it pulls in app code.
from log_events.views import sio  # noqa: E402

path_prefix = settings.PATH_PREFIX

application = socketio.WSGIApp(
    sio, django_app, socketio_path=f"{path_prefix}/socket"
)
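`log_events.views` is not shown in this part of the diff; a minimal sketch of what it might expose as `sio`, assuming python-socketio in WSGI/threading mode:

import socketio

# Hypothetical sketch of log_events/views.py (the real module lives elsewhere
# in the commit): a Socket.IO server for the WSGIApp above to wrap.
sio = socketio.Server(async_mode="threading", cors_allowed_origins="*")


@sio.event
def connect(sid, environ):
    # Fired when a client connects to the "/socket" endpoint.
    print(f"socket.io client connected: {sid}")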
0
backend/connector/__init__.py
Normal file
5
backend/connector/admin.py
Normal file
@@ -0,0 +1,5 @@
from django.contrib import admin

from .models import ConnectorInstance

admin.site.register(ConnectorInstance)
5
backend/connector/apps.py
Normal file
@@ -0,0 +1,5 @@
from django.apps import AppConfig


class ConnectorConfig(AppConfig):
    name = "connector"
316
backend/connector/connector_instance_helper.py
Normal file
@@ -0,0 +1,316 @@
import logging
from typing import Any, Optional

from account.models import User
from connector.constants import ConnectorInstanceConstant
from connector.models import ConnectorInstance
from connector.unstract_account import UnstractAccount
from django.conf import settings
from django.db import connection
from unstract.connectors.filesystems.ucs import UnstractCloudStorage
from unstract.connectors.filesystems.ucs.constants import UCSKey
from workflow_manager.workflow.models.workflow import Workflow

logger = logging.getLogger(__name__)


class ConnectorInstanceHelper:
    @staticmethod
    def create_default_gcs_connector(workflow: Workflow, user: User) -> None:
        """Method to create the default storage connector.

        Args:
            workflow (Workflow)
            user (User)
        """
        org_schema = connection.tenant.schema_name
        if not user.project_storage_created:
            logger.info("Creating default storage")
            account = UnstractAccount(org_schema, user.email)
            account.provision_s3_storage()
            account.upload_sample_files()
            user.project_storage_created = True
            user.save()
            logger.info("Default storage created successfully.")

        logger.info("Adding connectors to Unstract")
        connector_name = ConnectorInstanceConstant.USER_STORAGE
        gcs_id = UnstractCloudStorage.get_id()
        bucket_name = settings.UNSTRACT_FREE_STORAGE_BUCKET_NAME
        base_path = f"{bucket_name}/{org_schema}/{user.email}"

        connector_metadata = {
            UCSKey.KEY: settings.GOOGLE_STORAGE_ACCESS_KEY_ID,
            UCSKey.SECRET: settings.GOOGLE_STORAGE_SECRET_ACCESS_KEY,
            UCSKey.BUCKET: bucket_name,
            UCSKey.ENDPOINT_URL: settings.GOOGLE_STORAGE_BASE_URL,
        }
        connector_metadata__input = {
            **connector_metadata,
            UCSKey.PATH: base_path + "/input",
        }
        connector_metadata__output = {
            **connector_metadata,
            UCSKey.PATH: base_path + "/output",
        }
        ConnectorInstance.objects.create(
            connector_name=connector_name,
            workflow=workflow,
            created_by=user,
            connector_id=gcs_id,
            connector_metadata=connector_metadata__input,
            connector_type=ConnectorInstance.ConnectorType.INPUT,
            connector_mode=ConnectorInstance.ConnectorMode.FILE_SYSTEM,
        )
        ConnectorInstance.objects.create(
            connector_name=connector_name,
            workflow=workflow,
            created_by=user,
            connector_id=gcs_id,
            connector_metadata=connector_metadata__output,
            connector_type=ConnectorInstance.ConnectorType.OUTPUT,
            connector_mode=ConnectorInstance.ConnectorMode.FILE_SYSTEM,
        )
        logger.info("Connectors added successfully.")

    @staticmethod
    def get_connector_instances_by_workflow(
        workflow_id: str,
        connector_type: tuple[str, str],
        connector_mode: Optional[tuple[int, str]] = None,
        values: Optional[list[str]] = None,
        connector_name: Optional[str] = None,
    ) -> list[ConnectorInstance]:
        """Method to get connector instances by workflow.

        Args:
            workflow_id (str)
            connector_type (tuple[str, str]): Specifies input/output
            connector_mode (Optional[tuple[int, str]], optional):
                Specifies database/file system
            values (Optional[list[str]], optional): Field names to return
                instead of full instances. Defaults to None.
            connector_name (Optional[str], optional): Defaults to None.

        Returns:
            list[ConnectorInstance]
        """
        logger.info(f"Setting connector mode to {connector_mode}")
        filter_params: dict[str, Any] = {
            "workflow": workflow_id,
            "connector_type": connector_type,
        }
        if connector_mode is not None:
            filter_params["connector_mode"] = connector_mode
        if connector_name is not None:
            filter_params["connector_name"] = connector_name

        connector_instances = ConnectorInstance.objects.filter(
            **filter_params
        ).all()
        logger.info(f"Retrieved connector instances {connector_instances}")
        if values is not None:
            filtered_connector_instances = connector_instances.values(*values)
            logger.info(
                "Returning filtered connector instance values "
                f"{filtered_connector_instances}"
            )
            return list(filtered_connector_instances)
        logger.info(f"Returning connector instances {connector_instances}")
        return list(connector_instances)

    @staticmethod
    def get_connector_instance_by_workflow(
        workflow_id: str,
        connector_type: tuple[str, str],
        connector_mode: Optional[tuple[int, str]] = None,
        connector_name: Optional[str] = None,
    ) -> Optional[ConnectorInstance]:
        """Get one connector instance.

        Use this method if the connector instance is unique for the
        given filter_params.

        Args:
            workflow_id (str)
            connector_type (tuple[str, str]): Specifies input/output
            connector_mode (Optional[tuple[int, str]], optional):
                Specifies database/file system
            connector_name (Optional[str], optional)

        Returns:
            Optional[ConnectorInstance]
        """
        logger.info("Fetching connector instance by workflow")
        filter_params: dict[str, Any] = {
            "workflow": workflow_id,
            "connector_type": connector_type,
        }
        if connector_mode is not None:
            filter_params["connector_mode"] = connector_mode
        if connector_name is not None:
            filter_params["connector_name"] = connector_name

        try:
            connector_instance: ConnectorInstance = (
                ConnectorInstance.objects.filter(**filter_params).first()
            )
        except Exception as exc:
            logger.error(
                f"Error occurred while fetching connector instances {exc}"
            )
            raise exc

        return connector_instance

    @staticmethod
    def get_input_connector_instance_by_name_for_workflow(
        workflow_id: str,
        connector_name: str,
    ) -> Optional[ConnectorInstance]:
        """Method to get an input connector instance by name for a workflow.

        Args:
            workflow_id (str)
            connector_name (str)

        Returns:
            Optional[ConnectorInstance]
        """
        return ConnectorInstanceHelper.get_connector_instance_by_workflow(
            workflow_id=workflow_id,
            connector_type=ConnectorInstance.ConnectorType.INPUT,
            connector_name=connector_name,
        )

    @staticmethod
    def get_output_connector_instance_by_name_for_workflow(
        workflow_id: str,
        connector_name: str,
    ) -> Optional[ConnectorInstance]:
        """Method to get an output connector instance by name for a workflow.

        Args:
            workflow_id (str)
            connector_name (str)

        Returns:
            Optional[ConnectorInstance]
        """
        return ConnectorInstanceHelper.get_connector_instance_by_workflow(
            workflow_id=workflow_id,
            connector_type=ConnectorInstance.ConnectorType.OUTPUT,
            connector_name=connector_name,
        )

    @staticmethod
    def get_input_connector_instances_by_workflow(
        workflow_id: str,
    ) -> list[ConnectorInstance]:
        """Method to get input connector instances by workflow.

        Args:
            workflow_id (str)

        Returns:
            list[ConnectorInstance]
        """
        return ConnectorInstanceHelper.get_connector_instances_by_workflow(
            workflow_id, ConnectorInstance.ConnectorType.INPUT
        )

    @staticmethod
    def get_output_connector_instances_by_workflow(
        workflow_id: str,
    ) -> list[ConnectorInstance]:
        """Method to get output connector instances by workflow.

        Args:
            workflow_id (str)

        Returns:
            list[ConnectorInstance]
        """
        return ConnectorInstanceHelper.get_connector_instances_by_workflow(
            workflow_id, ConnectorInstance.ConnectorType.OUTPUT
        )

    @staticmethod
    def get_file_system_input_connector_instances_by_workflow(
        workflow_id: str, values: Optional[list[str]] = None
    ) -> list[ConnectorInstance]:
        """Method to fetch file system input connectors by workflow.

        Args:
            workflow_id (str)
            values (Optional[list[str]], optional)

        Returns:
            list[ConnectorInstance]
        """
        return ConnectorInstanceHelper.get_connector_instances_by_workflow(
            workflow_id,
            ConnectorInstance.ConnectorType.INPUT,
            ConnectorInstance.ConnectorMode.FILE_SYSTEM,
            values,
        )

    @staticmethod
    def get_file_system_output_connector_instances_by_workflow(
        workflow_id: str, values: Optional[list[str]] = None
    ) -> list[ConnectorInstance]:
        """Method to fetch file system output connectors by workflow.

        Args:
            workflow_id (str)
            values (Optional[list[str]], optional)

        Returns:
            list[ConnectorInstance]
        """
        return ConnectorInstanceHelper.get_connector_instances_by_workflow(
            workflow_id,
            ConnectorInstance.ConnectorType.OUTPUT,
            ConnectorInstance.ConnectorMode.FILE_SYSTEM,
            values,
        )

    @staticmethod
    def get_database_input_connector_instances_by_workflow(
        workflow_id: str, values: Optional[list[str]] = None
    ) -> list[ConnectorInstance]:
        """Method to fetch input database connectors by workflow.

        Args:
            workflow_id (str)
            values (Optional[list[str]], optional)

        Returns:
            list[ConnectorInstance]
        """
        return ConnectorInstanceHelper.get_connector_instances_by_workflow(
            workflow_id,
            ConnectorInstance.ConnectorType.INPUT,
            ConnectorInstance.ConnectorMode.DATABASE,
            values,
        )

    @staticmethod
    def get_database_output_connector_instances_by_workflow(
        workflow_id: str, values: Optional[list[str]] = None
    ) -> list[ConnectorInstance]:
        """Method to fetch output database connectors by workflow.

        Args:
            workflow_id (str)
            values (Optional[list[str]], optional)

        Returns:
            list[ConnectorInstance]
        """
        return ConnectorInstanceHelper.get_connector_instances_by_workflow(
            workflow_id,
            ConnectorInstance.ConnectorType.OUTPUT,
            ConnectorInstance.ConnectorMode.DATABASE,
            values,
        )
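A typical call site for the helpers above, e.g. when resolving a workflow's source connectors; a usage sketch in which the workflow UUID is a placeholder:

instances = ConnectorInstanceHelper.get_file_system_input_connector_instances_by_workflow(
    workflow_id="00000000-0000-0000-0000-000000000000",
    values=["id", "connector_id", "connector_metadata"],
)
# With `values` set, the underlying QuerySet yields dicts, not model instances.
for instance in instances:
    print(instance["connector_id"])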
17
backend/connector/constants.py
Normal file
@@ -0,0 +1,17 @@
class ConnectorInstanceKey:
    CONNECTOR_ID = "connector_id"
    CONNECTOR_NAME = "connector_name"
    CONNECTOR_TYPE = "connector_type"
    CONNECTOR_MODE = "connector_mode"
    CONNECTOR_VERSION = "connector_version"
    CONNECTOR_AUTH = "connector_auth"
    CONNECTOR_METADATA = "connector_metadata"
    CONNECTOR_METADATA_B = "connector_metadata_b"
    CONNECTOR_EXISTS = (
        "Connector with this configuration already exists in this project."
    )
    DUPLICATE_API = "It appears that a duplicate call may have been made."


class ConnectorInstanceConstant:
    USER_STORAGE = "User Storage"
38
backend/connector/fields.py
Normal file
@@ -0,0 +1,38 @@
import logging
from datetime import datetime

from connector_auth.constants import SocialAuthConstants
from connector_auth.models import ConnectorAuth
from django.db import models

logger = logging.getLogger(__name__)


class ConnectorAuthJSONField(models.JSONField):
    def from_db_value(self, value, expression, connection):  # type: ignore
        """Overrides the default implementation to refresh stale OAuth
        tokens before returning the metadata."""
        metadata = super().from_db_value(value, expression, connection)
        if not metadata:
            # NULL/empty column: nothing to inspect or refresh.
            return metadata
        provider = metadata.get(SocialAuthConstants.PROVIDER)
        uid = metadata.get(SocialAuthConstants.UID)
        if provider and uid:
            refresh_after_str = metadata.get(SocialAuthConstants.REFRESH_AFTER)
            if refresh_after_str:
                refresh_after = datetime.strptime(
                    refresh_after_str, SocialAuthConstants.REFRESH_AFTER_FORMAT
                )
                if datetime.now() > refresh_after:
                    metadata = self._refresh_tokens(provider, uid)
        return metadata

    def _refresh_tokens(self, provider: str, uid: str) -> dict[str, str]:
        """Retrieves the PSA object and refreshes the tokens if necessary."""
        connector_metadata: dict[str, str] = {}
        connector_auth: ConnectorAuth = ConnectorAuth.get_social_auth(
            provider=provider, uid=uid
        )
        if connector_auth:
            (
                connector_metadata,
                _tokens_refreshed,
            ) = connector_auth.get_and_refresh_tokens()
        return connector_metadata
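To illustrate the refresh-window check in `from_db_value` in isolation (the metadata key names and the `%Y-%m-%d %H:%M:%S` format string are assumptions; the real values live in `SocialAuthConstants`):

from datetime import datetime

REFRESH_AFTER_FORMAT = "%Y-%m-%d %H:%M:%S"  # assumed; see SocialAuthConstants

# Assumed key names; the real ones come from SocialAuthConstants.
metadata = {
    "provider": "google-oauth2",
    "uid": "109876543210",
    "refresh_after": "2024-01-01 00:00:00",
}

refresh_after = datetime.strptime(metadata["refresh_after"], REFRESH_AFTER_FORMAT)
if datetime.now() > refresh_after:
    print("tokens are stale; _refresh_tokens() would be called here")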
122
backend/connector/migrations/0001_initial.py
Normal file
@@ -0,0 +1,122 @@
# Generated by Django 4.2.1 on 2024-01-23 11:18

import uuid

import connector.fields
import django.db.models.deletion
from django.conf import settings
from django.db import migrations, models


class Migration(migrations.Migration):
    initial = True

    dependencies = [
        ("project", "0001_initial"),
        migrations.swappable_dependency(settings.AUTH_USER_MODEL),
        ("workflow", "0001_initial"),
        ("connector_auth", "0001_initial"),
    ]

    operations = [
        migrations.CreateModel(
            name="ConnectorInstance",
            fields=[
                ("created_at", models.DateTimeField(auto_now_add=True)),
                ("modified_at", models.DateTimeField(auto_now=True)),
                (
                    "id",
                    models.UUIDField(
                        default=uuid.uuid4,
                        editable=False,
                        primary_key=True,
                        serialize=False,
                    ),
                ),
                ("connector_name", models.TextField(max_length=128)),
                ("connector_id", models.CharField(default="", max_length=128)),
                (
                    "connector_metadata",
                    connector.fields.ConnectorAuthJSONField(
                        db_column="connector_metadata", default=dict
                    ),
                ),
                (
                    "connector_version",
                    models.CharField(default="", max_length=64),
                ),
                (
                    "connector_type",
                    models.CharField(
                        choices=[("INPUT", "Input"), ("OUTPUT", "Output")]
                    ),
                ),
                (
                    "connector_mode",
                    models.CharField(
                        choices=[
                            (0, "UNKNOWN"),
                            (1, "FILE_SYSTEM"),
                            (2, "DATABASE"),
                        ],
                        db_comment="0: UNKNOWN, 1: FILE_SYSTEM, 2: DATABASE",
                        default=0,
                    ),
                ),
                (
                    "connector_auth",
                    models.ForeignKey(
                        blank=True,
                        null=True,
                        on_delete=django.db.models.deletion.SET_NULL,
                        to="connector_auth.connectorauth",
                    ),
                ),
                (
                    "created_by",
                    models.ForeignKey(
                        blank=True,
                        null=True,
                        on_delete=django.db.models.deletion.SET_NULL,
                        related_name="created_connectors",
                        to=settings.AUTH_USER_MODEL,
                    ),
                ),
                (
                    "modified_by",
                    models.ForeignKey(
                        blank=True,
                        null=True,
                        on_delete=django.db.models.deletion.SET_NULL,
                        related_name="modified_connectors",
                        to=settings.AUTH_USER_MODEL,
                    ),
                ),
                (
                    "project",
                    models.ForeignKey(
                        blank=True,
                        null=True,
                        on_delete=django.db.models.deletion.CASCADE,
                        related_name="project_connectors",
                        to="project.project",
                    ),
                ),
                (
                    "workflow",
                    models.ForeignKey(
                        on_delete=django.db.models.deletion.CASCADE,
                        related_name="workflow_connectors",
                        to="workflow.workflow",
                    ),
                ),
            ],
        ),
        migrations.AddConstraint(
            model_name="connectorinstance",
            constraint=models.UniqueConstraint(
                fields=("connector_name", "workflow", "connector_type"),
                name="unique_connector",
            ),
        ),
    ]
@@ -0,0 +1,42 @@
# Generated by Django 4.2.1 on 2024-02-16 06:50

import json
from typing import Any

from account.models import EncryptionSecret
from connector.models import ConnectorInstance
from cryptography.fernet import Fernet
from django.db import migrations, models


class Migration(migrations.Migration):
    dependencies = [
        ("connector", "0001_initial"),
        ("account", "0005_encryptionsecret"),
    ]

    def EncryptCredentials(apps: Any, schema_editor: Any) -> None:
        encryption_secret: EncryptionSecret = EncryptionSecret.objects.get()
        f: Fernet = Fernet(encryption_secret.key.encode("utf-8"))
        queryset = ConnectorInstance.objects.all()

        for obj in queryset:  # type: ignore
            # Access attributes of the object
            if hasattr(obj, "connector_metadata"):
                json_string: str = json.dumps(obj.connector_metadata)
                obj.connector_metadata_b = f.encrypt(
                    json_string.encode("utf-8")
                )
                obj.save()

    operations = [
        migrations.AddField(
            model_name="connectorinstance",
            name="connector_metadata_b",
            field=models.BinaryField(null=True),
        ),
        migrations.RunPython(
            EncryptCredentials, reverse_code=migrations.RunPython.noop
        ),
    ]
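The inverse of `EncryptCredentials` follows directly from the code above, e.g. for reading credentials back at runtime (a sketch; error handling omitted):

import json

from cryptography.fernet import Fernet


def decrypt_connector_metadata(key: str, encrypted: bytes) -> dict:
    """Decrypts a connector_metadata_b payload written by EncryptCredentials."""
    f = Fernet(key.encode("utf-8"))
    return json.loads(f.decrypt(bytes(encrypted)).decode("utf-8"))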
0
backend/connector/migrations/__init__.py
Normal file
116
backend/connector/models.py
Normal file
@@ -0,0 +1,116 @@
import uuid

from account.models import User
from connector.fields import ConnectorAuthJSONField
from connector_auth.models import ConnectorAuth
from connector_processor.connector_processor import ConnectorProcessor
from connector_processor.constants import ConnectorKeys
from django.db import models
from project.models import Project
from utils.models.base_model import BaseModel
from workflow_manager.workflow.models import Workflow

from backend.constants import FieldLengthConstants as FLC

CONNECTOR_NAME_SIZE = 128
VERSION_NAME_SIZE = 64


class ConnectorInstance(BaseModel):
    class ConnectorType(models.TextChoices):
        INPUT = "INPUT", "Input"
        OUTPUT = "OUTPUT", "Output"

    class ConnectorMode(models.IntegerChoices):
        UNKNOWN = 0, "UNKNOWN"
        FILE_SYSTEM = 1, "FILE_SYSTEM"
        DATABASE = 2, "DATABASE"

    id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False)
    connector_name = models.TextField(
        max_length=CONNECTOR_NAME_SIZE, null=False, blank=False
    )
    project = models.ForeignKey(
        Project,
        on_delete=models.CASCADE,
        related_name="project_connectors",
        null=True,
        blank=True,
    )
    workflow = models.ForeignKey(
        Workflow,
        on_delete=models.CASCADE,
        related_name="workflow_connectors",
        null=False,
        blank=False,
    )
    connector_id = models.CharField(
        max_length=FLC.CONNECTOR_ID_LENGTH, default=""
    )
    # TODO: Remove once the encrypted connector_metadata_b replaces this field.
    connector_metadata = ConnectorAuthJSONField(
        db_column="connector_metadata", null=False, blank=False, default=dict
    )
    connector_metadata_b = models.BinaryField(null=True)
    connector_version = models.CharField(
        max_length=VERSION_NAME_SIZE, default=""
    )
    connector_type = models.CharField(choices=ConnectorType.choices)
    connector_auth = models.ForeignKey(
        ConnectorAuth, on_delete=models.SET_NULL, null=True, blank=True
    )
    connector_mode = models.CharField(
        choices=ConnectorMode.choices,
        default=ConnectorMode.UNKNOWN,
        db_comment="0: UNKNOWN, 1: FILE_SYSTEM, 2: DATABASE",
    )

    created_by = models.ForeignKey(
        User,
        on_delete=models.SET_NULL,
        related_name="created_connectors",
        null=True,
        blank=True,
    )
    modified_by = models.ForeignKey(
        User,
        on_delete=models.SET_NULL,
        related_name="modified_connectors",
        null=True,
        blank=True,
    )

    def get_connector_metadata(self) -> dict[str, str]:
        """Gets connector metadata and refreshes the tokens if needed in case
        of OAuth."""
        tokens_refreshed = False
        if self.connector_auth:
            (
                self.connector_metadata,
                tokens_refreshed,
            ) = self.connector_auth.get_and_refresh_tokens()
        if tokens_refreshed:
            self.save()
        return self.connector_metadata

    @staticmethod
    def supportsOAuth(connector_id: str) -> bool:
        return bool(
            ConnectorProcessor.get_connector_data_with_key(
                connector_id, ConnectorKeys.OAUTH
            )
        )

    def __str__(self) -> str:
        return (
            f"Connector({self.id}, type: {self.connector_type},"
            f" workflow: {self.workflow})"
        )

    class Meta:
        constraints = [
            models.UniqueConstraint(
                fields=["connector_name", "workflow", "connector_type"],
                name="unique_connector",
            ),
        ]
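Typical call sites for the model above; a usage sketch in which the connector id and primary key are placeholders:

from connector.models import ConnectorInstance

# Check whether a connector type supports OAuth before starting an auth flow.
if ConnectorInstance.supportsOAuth(connector_id="<connector-id>"):
    print("OAuth flow required")

# Fetch metadata; OAuth tokens are refreshed transparently if stale.
instance = ConnectorInstance.objects.get(pk="<instance-uuid>")
metadata = instance.get_connector_metadata()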
Some files were not shown because too many files have changed in this diff.