Fix/write-db-connector-error-handling (#330)
* fix connector insert
* error handling started
* error handling for postgres/snowflake
* test cases for postgres/snowflake
* test cases for postgres/snowflake/redshift
* test cases covered for postgres, redshift
* test cases covered for bigquery
* test cases covered for postgres, snowflake, redshift, bigquery
* test cases covered for postgres, snowflake, redshift, bigquery, mssql
* test cases covered for postgres, snowflake, redshift, bigquery, mssql, mariadb, mysql
* adding lock file
* adding celery change
* test connector changes
* removing test connection
* removing KeyNotFound exception
* refactoring db_utils
* sample.test.env changes
* fixing celery part in test cases
* refactoring connectors
* refactoring BigQuery connector
* removing print in database_utils
* PR review changes

Co-authored-by: Ritwik G <100672805+ritwik-g@users.noreply.github.com>
.gitignore (vendored): 1 line changed
@@ -132,6 +132,7 @@ celerybeat.pid
 *.sage.py

 # Environments
+test*.env
 .env
 .env.export
 .venv*
backend/pdm.lock (generated): 16 lines changed
@@ -3453,6 +3453,20 @@ files = [
     {file = "pytest-8.2.2.tar.gz", hash = "sha256:de4bb8104e201939ccdc688b27a89a7be2079b22e2bd2b07f806b6ba71117977"},
 ]

+[[package]]
+name = "pytest-dotenv"
+version = "0.5.2"
+summary = "A py.test plugin that parses environment files before running tests"
+groups = ["test"]
+dependencies = [
+    "pytest>=5.0.0",
+    "python-dotenv>=0.9.1",
+]
+files = [
+    {file = "pytest-dotenv-0.5.2.tar.gz", hash = "sha256:2dc6c3ac6d8764c71c6d2804e902d0ff810fa19692e95fe138aefc9b1aa73732"},
+    {file = "pytest_dotenv-0.5.2-py3-none-any.whl", hash = "sha256:40a2cece120a213898afaa5407673f6bd924b1fa7eafce6bda0e8abffe2f710f"},
+]
+
 [[package]]
 name = "python-crontab"
 version = "3.1.0"
@@ -3484,7 +3498,7 @@ name = "python-dotenv"
 version = "1.0.0"
 requires_python = ">=3.8"
 summary = "Read key-value pairs from a .env file and set them as environment variables"
-groups = ["default"]
+groups = ["default", "test"]
 files = [
     {file = "python-dotenv-1.0.0.tar.gz", hash = "sha256:a8df96034aae6d2d50a4ebe8216326c61c3eb64836776504fcca410e5937a3ba"},
     {file = "python_dotenv-1.0.0-py3-none-any.whl", hash = "sha256:f5971a9226b701070a4bf2c38c89e5a3f0d64de8debda981d1db98583009122a"},
backend/pyproject.toml
@@ -66,8 +66,13 @@ deploy = [
 ]
 test = [
     "pytest>=8.0.1",
+    "pytest-dotenv==0.5.2",
 ]

+[tool.pytest.ini_options]
+env_files = "test.env"  # Load env from particular env file
+addopts = "-s"
+
 [tool.pdm.scripts]
 # Commands for backend
 backend.cmd = "./entrypoint.sh"
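With pytest-dotenv wired in through `[tool.pytest.ini_options]`, the variables in test.env are exported before test collection, so fixtures can read them with plain `os.getenv`. A minimal sketch of what that buys a test (the variable names come from sample.test.env below; the asserted values assume the defaults shipped there):

```python
import os


def test_postgres_env_loaded() -> None:
    # pytest-dotenv has already parsed test.env at this point, so the
    # connection settings are ordinary environment variables.
    assert os.getenv("DB_HOST") == "unstract-db"
    assert os.getenv("DB_PORT") == "5432"
```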
backend/sample.test.env (new file): 40 lines
@@ -0,0 +1,40 @@
DB_HOST='unstract-db'
DB_USER='unstract_dev'
DB_PASSWORD='unstract_pass'
DB_NAME='unstract_db'
DB_PORT=5432


REDSHIFT_HOST=
REDSHIFT_PORT="5439"
REDSHIFT_DB=
REDSHIFT_USER=
REDSHIFT_PASSWORD=

BIGQUERY_CREDS=

SNOWFLAKE_USER=
SNOWFLAKE_PASSWORD=
SNOWFLAKE_ACCOUNT=
SNOWFLAKE_ROLE=
SNOWFLAKE_DB=
SNOWFLAKE_SCHEMA=
SNOWFLAKE_WAREHOUSE=

MSSQL_SERVER=
MSSQL_PORT=
MSSQL_PASSWORD=
MSSQL_DB=
MSSQL_USER=

MYSQL_SERVER=
MYSQL_PORT=
MYSQL_PASSWORD=
MYSQL_DB=
MYSQL_USER=

MARIADB_SERVER=
MARIADB_PORT=
MARIADB_PASSWORD=
MARIADB_DB=
MARIADB_USER=
backend/workflow_manager/endpoint/constants.py
@@ -5,8 +5,9 @@ class TableColumns:


 class DBConnectionClass:
-    SNOWFLAKE = "SnowflakeConnection"
-    BIGQUERY = "Client"
+    SNOWFLAKE = "SnowflakeDB"
+    BIGQUERY = "BigQuery"
+    MSSQL = "MSSQL"


 class Snowflake:
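These constants previously named the driver's engine class (snowflake-connector's SnowflakeConnection, BigQuery's Client); after this change the code compares against the Unstract connector class name instead, taken from `db_class.__class__.__name__`. A tiny illustration with a stand-in class:

```python
class SnowflakeDB:
    """Stand-in for unstract.connectors.databases.snowflake.SnowflakeDB."""


db_class = SnowflakeDB()
cls_name = db_class.__class__.__name__
assert cls_name == "SnowflakeDB"  # now matches DBConnectionClass.SNOWFLAKE
```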
backend/workflow_manager/endpoint/database_utils.py
@@ -8,22 +8,27 @@ from utils.constants import Common
 from workflow_manager.endpoint.constants import (
     BigQuery,
     DBConnectionClass,
     Snowflake,
     TableColumns,
 )
-from workflow_manager.endpoint.exceptions import BigQueryTableNotFound
+from workflow_manager.endpoint.db_connector_helper import DBConnectorQueryHelper
+from workflow_manager.endpoint.exceptions import (
+    BigQueryTableNotFound,
+    UnstractDBException,
+)
 from workflow_manager.workflow.enums import AgentName, ColumnModes

 from unstract.connectors.databases import connectors as db_connectors
+from unstract.connectors.databases.exceptions import UnstractDBConnectorException
 from unstract.connectors.databases.unstract_db import UnstractDB
 from unstract.connectors.exceptions import ConnectorError

 logger = logging.getLogger(__name__)


 class DatabaseUtils:
     @staticmethod
-    def make_sql_values_for_query(
-        values: dict[str, Any], column_types: dict[str, str], cls: Any = None
+    def get_sql_values_for_query(
+        values: dict[str, Any], column_types: dict[str, str], cls_name: str
     ) -> dict[str, str]:
         """Making Sql Columns and Values for Query.
@@ -51,32 +56,29 @@ class DatabaseUtils:
         insert into the table accordingly
         """
         sql_values: dict[str, Any] = {}
-
         for column in values:
-            if cls == DBConnectionClass.SNOWFLAKE:
+            if cls_name == DBConnectionClass.SNOWFLAKE:
                 col = column.lower()
                 type_x = column_types[col]
-                if type_x in Snowflake.COLUMN_TYPES:
-                    sql_values[column] = f"'{values[column]}'"
-                elif type_x == "VARIANT":
+                if type_x == "VARIANT":
                     values[column] = values[column].replace("'", "\\'")
                     sql_values[column] = f"parse_json($${values[column]}$$)"
                 else:
                     sql_values[column] = f"{values[column]}"
-            elif cls == DBConnectionClass.BIGQUERY:
+            elif cls_name == DBConnectionClass.BIGQUERY:
                 col = column.lower()
                 type_x = column_types[col]
                 if type_x in BigQuery.COLUMN_TYPES:
-                    sql_values[column] = f"{type_x}('{values[column]}')"
+                    sql_values[column] = f"{type_x}({values[column]})"
                 else:
-                    sql_values[column] = f"'{values[column]}'"
+                    sql_values[column] = f"{values[column]}"
             else:
                 # Default to Other SQL DBs
                 # TODO: Handle numeric types with no quotes
-                sql_values[column] = f"'{values[column]}'"
+                sql_values[column] = f"{values[column]}"
         if column_types.get("id"):
             uuid_id = str(uuid.uuid4())
-            sql_values["id"] = f"'{uuid_id}'"
+            sql_values["id"] = f"{uuid_id}"
         return sql_values

     @staticmethod
@@ -96,7 +98,7 @@ class DatabaseUtils:

     @staticmethod
     def get_column_types(
-        cls: Any,
+        cls_name: Any,
         table_name: str,
         connector_id: str,
         connector_settings: dict[str, Any],
@@ -118,7 +120,7 @@ class DatabaseUtils:
         """
         column_types: dict[str, str] = {}
         try:
-            if cls == DBConnectionClass.SNOWFLAKE:
+            if cls_name == DBConnectionClass.SNOWFLAKE:
                 query = f"describe table {table_name}"
                 results = DatabaseUtils.execute_and_fetch_data(
                     connector_id=connector_id,
@@ -127,7 +129,7 @@ class DatabaseUtils:
                 )
                 for column in results:
                     column_types[column[0].lower()] = column[1].split("(")[0]
-            elif cls == DBConnectionClass.BIGQUERY:
+            elif cls_name == DBConnectionClass.BIGQUERY:
                 bigquery_table_name = str.lower(table_name).split(".")
                 if len(bigquery_table_name) != BigQuery.TABLE_NAME_SIZE:
                     raise BigQueryTableNotFound()
@@ -223,8 +225,8 @@ class DatabaseUtils:
         return values

     @staticmethod
-    def get_sql_columns_and_values_for_query(
-        engine: Any,
+    def get_sql_query_data(
+        cls_name: str,
         connector_id: str,
         connector_settings: dict[str, Any],
         table_name: str,
@@ -234,7 +236,6 @@ class DatabaseUtils:
         provided values and table schema.

         Args:
-            engine (Any): The database engine.
             connector_id: The connector id of the connector provided
             connector_settings: Connector settings provided by user
             table_name (str): The name of the target table for the insert query.
@@ -252,26 +253,26 @@ class DatabaseUtils:
             - For other SQL databases, it uses default SQL generation
               based on column types.
         """
-        class_name = engine.__class__.__name__
         column_types: dict[str, str] = DatabaseUtils.get_column_types(
-            cls=class_name,
+            cls_name=cls_name,
             table_name=table_name,
             connector_id=connector_id,
             connector_settings=connector_settings,
         )
-        sql_columns_and_values = DatabaseUtils.make_sql_values_for_query(
+        sql_columns_and_values = DatabaseUtils.get_sql_values_for_query(
            values=values,
            column_types=column_types,
-           cls=class_name,
+           cls_name=cls_name,
        )
        return sql_columns_and_values

     @staticmethod
     def execute_write_query(
+        db_class: UnstractDB,
         engine: Any,
         table_name: str,
         sql_keys: list[str],
-        sql_values: list[str],
+        sql_values: Any,
     ) -> None:
         """Execute Insert Query.
@@ -284,22 +285,30 @@ class DatabaseUtils:
         - Snowflake does not support INSERT INTO ... VALUES ...
           syntax when VARIANT columns are present (JSON).
           So we need to use INSERT INTO ... SELECT ... syntax
+        - sql values can contain data with single quote. It needs to
         """
-        sql = (
-            f"INSERT INTO {table_name} ({','.join(sql_keys)}) "
-            f"SELECT {','.join(sql_values)}"
+        cls_name = db_class.__class__.__name__
+        sql = DBConnectorQueryHelper.build_sql_insert_query(
+            cls_name=cls_name, table_name=table_name, sql_keys=sql_keys
         )
-        logger.debug(f"insertng into table with: {sql} query")
+        logger.debug(f"inserting into table {table_name} with: {sql} query")
+
+        sql_values = DBConnectorQueryHelper.prepare_sql_values(
+            cls_name=cls_name, sql_values=sql_values, sql_keys=sql_keys
+        )
+        logger.debug(f"sql_values: {sql_values}")
+
         try:
-            if hasattr(engine, "cursor"):
-                with engine.cursor() as cursor:
-                    cursor.execute(sql)
-                engine.commit()
-            else:
-                engine.query(sql)
-        except Exception as e:
-            logger.error(f"Error while writing data: {str(e)}")
-            raise e
+            db_class.execute_query(
+                engine=engine,
+                sql_query=sql,
+                sql_values=sql_values,
+                table_name=table_name,
+                sql_keys=sql_keys,
+            )
+        except UnstractDBConnectorException as e:
+            raise UnstractDBException(detail=e.detail) from e
+        logger.debug(f"sucessfully inserted into table {table_name} with: {sql} query")

     @staticmethod
     def get_db_class(
@@ -315,7 +324,10 @@ class DatabaseUtils:
     ) -> Any:
         connector = db_connectors[connector_id][Common.METADATA][Common.CONNECTOR]
         connector_class: UnstractDB = connector(connector_settings)
-        return connector_class.execute(query=query)
+        try:
+            return connector_class.execute(query=query)
+        except ConnectorError as e:
+            raise UnstractDBException(detail=e.message) from e

     @staticmethod
     def create_table_if_not_exists(
@@ -328,7 +340,6 @@ class DatabaseUtils:

         Args:
             class_name (UnstractDB): Type of Unstract DB connector
-            engine (Any): _description_
             table_name (str): _description_
             database_entry (dict[str, Any]): _description_

@@ -338,62 +349,9 @@ class DatabaseUtils:
         sql = DBConnectorQueryHelper.create_table_query(
             conn_cls=db_class, table=table_name, database_entry=database_entry
         )
-        logger.debug(f"creating table with: {sql} query")
+        logger.debug(f"creating table {table_name} with: {sql} query")
         try:
-            if hasattr(engine, "cursor"):
-                with engine.cursor() as cursor:
-                    cursor.execute(sql)
-                engine.commit()
-            else:
-                engine.query(sql)
-        except Exception as e:
-            logger.error(f"Error while creating table: {str(e)}")
-            raise e
-
-
-class DBConnectorQueryHelper:
-    """A class that helps to generate query for connector table operations."""
-
-    @staticmethod
-    def create_table_query(
-        conn_cls: UnstractDB, table: str, database_entry: dict[str, Any]
-    ) -> Any:
-        sql_query = ""
-        """Generate a SQL query to create a table, based on the provided
-        database entry.
-
-        Args:
-            conn_cls (str): The database connector class.
-                Should be one of 'BIGQUERY', 'SNOWFLAKE', or other.
-            table (str): The name of the table to be created.
-            database_entry (dict[str, Any]):
-                A dictionary containing column names as keys
-                and their corresponding values.
-
-                These values are used to determine the data types,
-                for the columns in the table.
-
-        Returns:
-            str: A SQL query string to create a table with the specified name,
-            and column definitions.
-
-        Note:
-            - Each conn_cls have it's implementation for SQL create table query
-              Based on the implementation, a base SQL create table query will be
-              created containing Permanent columns
-            - Each conn_cls also has a mapping to convert python datatype to
-              corresponding column type (string, VARCHAR etc)
-            - keys in database_entry will be converted to column type, and
-              column values will be the valus in database_entry
-            - base SQL create table will be appended based column type and
-              values, and generates a complete SQL create table query
-        """
-        create_table_query = conn_cls.get_create_table_query(table=table)
-        sql_query += create_table_query
-
-        for key, val in database_entry.items():
-            if key not in TableColumns.PERMANENT_COLUMNS:
-                sql_type = conn_cls.sql_to_db_mapping(val)
-                sql_query += f"{key} {sql_type}, "
-
-        return sql_query.rstrip(", ") + ");"
+            db_class.execute_query(engine=engine, sql_query=sql, sql_values=None)
+        except UnstractDBConnectorException as e:
+            raise UnstractDBException(detail=e.detail) from e
+        logger.debug(f"successfully created table {table_name} with: {sql} query")
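Taken together, the refactored flow in this file is: resolve the target table's column types, render the row per SQL dialect, then hand execution to the connector's own execute_query, with connector-level UnstractDBConnectorException translated into the API-facing UnstractDBException. A hedged sketch of a caller, mirroring the destination connector further down (connector_id and connector_settings are placeholders):

```python
from workflow_manager.endpoint.database_utils import DatabaseUtils


def write_row(db_class, engine, table_name, values, connector_id, connector_settings):
    # Dialect-aware rendering of one row, keyed by the connector class name
    # (e.g. "PostgreSQL", "SnowflakeDB", "BigQuery").
    data = DatabaseUtils.get_sql_query_data(
        cls_name=db_class.__class__.__name__,
        connector_id=connector_id,
        connector_settings=connector_settings,
        table_name=table_name,
        values=values,
    )
    # Parameterized insert; failures surface as UnstractDBException (HTTP 500).
    DatabaseUtils.execute_write_query(
        db_class=db_class,
        engine=engine,
        table_name=table_name,
        sql_keys=list(data.keys()),
        sql_values=list(data.values()),
    )
```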
backend/workflow_manager/endpoint/db_connector_helper.py (new file): 77 lines
@@ -0,0 +1,77 @@
from typing import Any

from google.cloud import bigquery
from workflow_manager.endpoint.constants import DBConnectionClass, TableColumns

from unstract.connectors.databases.unstract_db import UnstractDB


class DBConnectorQueryHelper:
    """A class that helps to generate query for connector table operations."""

    @staticmethod
    def create_table_query(
        conn_cls: UnstractDB, table: str, database_entry: dict[str, Any]
    ) -> Any:
        sql_query = ""
        """Generate a SQL query to create a table, based on the provided
        database entry.

        Args:
            conn_cls (str): The database connector class.
                Should be one of 'BIGQUERY', 'SNOWFLAKE', or other.
            table (str): The name of the table to be created.
            database_entry (dict[str, Any]):
                A dictionary containing column names as keys
                and their corresponding values.

                These values are used to determine the data types,
                for the columns in the table.

        Returns:
            str: A SQL query string to create a table with the specified name,
            and column definitions.

        Note:
            - Each conn_cls have it's implementation for SQL create table query
              Based on the implementation, a base SQL create table query will be
              created containing Permanent columns
            - Each conn_cls also has a mapping to convert python datatype to
              corresponding column type (string, VARCHAR etc)
            - keys in database_entry will be converted to column type, and
              column values will be the valus in database_entry
            - base SQL create table will be appended based column type and
              values, and generates a complete SQL create table query
        """
        create_table_query = conn_cls.get_create_table_query(table=table)
        sql_query += create_table_query

        for key, val in database_entry.items():
            if key not in TableColumns.PERMANENT_COLUMNS:
                sql_type = conn_cls.sql_to_db_mapping(val)
                sql_query += f"{key} {sql_type}, "

        return sql_query.rstrip(", ") + ");"

    @staticmethod
    def build_sql_insert_query(
        cls_name: str, table_name: str, sql_keys: list[str]
    ) -> str:
        keys_str = ",".join(sql_keys)
        if cls_name == DBConnectionClass.BIGQUERY:
            values_placeholder = ",".join(["@" + key for key in sql_keys])
        else:
            values_placeholder = ",".join(["%s" for _ in sql_keys])
        return f"INSERT INTO {table_name} ({keys_str}) VALUES ({values_placeholder})"

    @staticmethod
    def prepare_sql_values(cls_name: str, sql_values: Any, sql_keys: list[str]) -> Any:
        if cls_name == DBConnectionClass.MSSQL:
            return tuple(sql_values)
        elif cls_name == DBConnectionClass.BIGQUERY:
            query_parameters = [
                bigquery.ScalarQueryParameter(key, "STRING", value)
                for key, value in zip(sql_keys, sql_values)
            ]
            return bigquery.QueryJobConfig(query_parameters=query_parameters)
        return sql_values
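How the two helpers fit together: most drivers get DB-API `%s` placeholders with a plain value sequence, while BigQuery gets named `@key` placeholders plus a QueryJobConfig. A small demonstration of the generated SQL, requiring no database connection:

```python
from workflow_manager.endpoint.db_connector_helper import DBConnectorQueryHelper

keys = ["id", "created_by", "data"]

sql = DBConnectorQueryHelper.build_sql_insert_query(
    cls_name="PostgreSQL", table_name="test_output", sql_keys=keys
)
# -> INSERT INTO test_output (id,created_by,data) VALUES (%s,%s,%s)

bq_sql = DBConnectorQueryHelper.build_sql_insert_query(
    cls_name="BigQuery", table_name="proj.dataset.table", sql_keys=keys
)
# -> INSERT INTO proj.dataset.table (id,created_by,data) VALUES (@id,@created_by,@data)
```

Moving from string-interpolated values to placeholders is what eliminates the single-quote escaping problems the old `INSERT ... SELECT` path had.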
backend/workflow_manager/endpoint/destination.py
@@ -214,14 +214,16 @@ class DestinationConnector(BaseConnector):
             table_name=table_name,
             database_entry=values,
         )
-        sql_columns_and_values = DatabaseUtils.get_sql_columns_and_values_for_query(
-            engine=engine,
+        cls_name = db_class.__class__.__name__
+        sql_columns_and_values = DatabaseUtils.get_sql_query_data(
+            cls_name=cls_name,
             connector_id=connector_instance.connector_id,
             connector_settings=connector_settings,
             table_name=table_name,
             values=values,
         )
         DatabaseUtils.execute_write_query(
+            db_class=db_class,
             engine=engine,
             table_name=table_name,
             sql_keys=list(sql_columns_and_values.keys()),
backend/workflow_manager/endpoint/exceptions.py
@@ -69,3 +69,11 @@ class BigQueryTableNotFound(APIException):
         "Please enter correct correct bigquery table in the form "
         "{table}.{schema}.{database}."
     )


+class UnstractDBException(APIException):
+    default_detail = "Error creating/inserting to database. "
+
+    def __init__(self, detail: str = default_detail) -> None:
+        status_code = 500
+        super().__init__(detail=detail, code=status_code)
backend/workflow_manager/endpoint/tests/__init__.py (new file): 3 lines
@@ -0,0 +1,3 @@
from backend.celery import app as celery_app

__all__ = ["celery_app"]
backend/workflow_manager/endpoint/tests/base_test_db.py (new file)
@@ -0,0 +1,160 @@
import datetime
import json
import os
from typing import Any

import pytest  # type: ignore
from dotenv import load_dotenv

from unstract.connectors.databases.bigquery import BigQuery
from unstract.connectors.databases.mariadb import MariaDB
from unstract.connectors.databases.mssql import MSSQL
from unstract.connectors.databases.mysql import MySQL
from unstract.connectors.databases.postgresql import PostgreSQL
from unstract.connectors.databases.redshift import Redshift
from unstract.connectors.databases.snowflake import SnowflakeDB
from unstract.connectors.databases.unstract_db import UnstractDB

load_dotenv("test.env")


class BaseTestDB:
    @pytest.fixture(autouse=True)
    def base_setup(self) -> None:
        self.postgres_creds = {
            "user": os.getenv("DB_USER"),
            "password": os.getenv("DB_PASSWORD"),
            "host": os.getenv("DB_HOST"),
            "port": os.getenv("DB_PORT"),
            "database": os.getenv("DB_NAME"),
        }
        self.redshift_creds = {
            "user": os.getenv("REDSHIFT_USER"),
            "password": os.getenv("REDSHIFT_PASSWORD"),
            "host": os.getenv("REDSHIFT_HOST"),
            "port": os.getenv("REDSHIFT_PORT"),
            "database": os.getenv("REDSHIFT_DB"),
        }
        self.snowflake_creds = {
            "user": os.getenv("SNOWFLAKE_USER"),
            "password": os.getenv("SNOWFLAKE_PASSWORD"),
            "account": os.getenv("SNOWFLAKE_ACCOUNT"),
            "role": os.getenv("SNOWFLAKE_ROLE"),
            "database": os.getenv("SNOWFLAKE_DB"),
            "schema": os.getenv("SNOWFLAKE_SCHEMA"),
            "warehouse": os.getenv("SNOWFLAKE_WAREHOUSE"),
        }
        self.mssql_creds = {
            "user": os.getenv("MSSQL_USER"),
            "password": os.getenv("MSSQL_PASSWORD"),
            "server": os.getenv("MSSQL_SERVER"),
            "port": os.getenv("MSSQL_PORT"),
            "database": os.getenv("MSSQL_DB"),
        }
        self.mysql_creds = {
            "user": os.getenv("MYSQL_USER"),
            "password": os.getenv("MYSQL_PASSWORD"),
            "host": os.getenv("MYSQL_SERVER"),
            "port": os.getenv("MYSQL_PORT"),
            "database": os.getenv("MYSQL_DB"),
        }
        self.mariadb_creds = {
            "user": os.getenv("MARIADB_USER"),
            "password": os.getenv("MARIADB_PASSWORD"),
            "host": os.getenv("MARIADB_SERVER"),
            "port": os.getenv("MARIADB_PORT"),
            "database": os.getenv("MARIADB_DB"),
        }
        self.database_entry = {
            "created_by": "Unstract/DBWriter",
            "created_at": datetime.datetime(2024, 5, 20, 7, 46, 57, 307998),
            "data": '{"input_file": "simple.pdf", "result": "report"}',
        }
        valid_schema_name = "public"
        invalid_schema_name = "public_1"
        self.valid_postgres_creds = {**self.postgres_creds, "schema": valid_schema_name}
        self.invalid_postgres_creds = {
            **self.postgres_creds,
            "schema": invalid_schema_name,
        }
        self.valid_redshift_creds = {**self.redshift_creds, "schema": valid_schema_name}
        self.invalid_redshift_creds = {
            **self.redshift_creds,
            "schema": invalid_schema_name,
        }
        self.invalid_syntax_table_name = "invalid-syntax.name.test_output"
        self.invalid_wrong_table_name = "database.schema.test_output"
        self.valid_table_name = "test_output"
        bigquery_json_str = os.getenv("BIGQUERY_CREDS", "{}")
        self.bigquery_settings = json.loads(bigquery_json_str)
        self.bigquery_settings["json_credentials"] = bigquery_json_str
        self.valid_bigquery_table_name = "pandoras-tamer.bigquery_test.bigquery_output"
        self.invalid_snowflake_db = {**self.snowflake_creds, "database": "invalid"}
        self.invalid_snowflake_schema = {**self.snowflake_creds, "schema": "invalid"}
        self.invalid_snowflake_warehouse = {
            **self.snowflake_creds,
            "warehouse": "invalid",
        }

    # Gets all valid db instances except
    # Bigquery (table name needs to be writted separately for bigquery)
    @pytest.fixture(
        params=[
            ("valid_postgres_creds", PostgreSQL),
            ("snowflake_creds", SnowflakeDB),
            ("mssql_creds", MSSQL),
            ("mysql_creds", MySQL),
            ("mariadb_creds", MariaDB),
            ("valid_redshift_creds", Redshift),
        ]
    )
    def valid_dbs_instance(self, request: Any) -> Any:
        return self.get_db_instance(request=request)

    # Gets all valid db instances except:
    # Bigquery (table name needs to be writted separately for bigquery)
    # Redshift (can't process more than 64KB character type)
    @pytest.fixture(
        params=[
            ("valid_postgres_creds", PostgreSQL),
            ("snowflake_creds", SnowflakeDB),
            ("mssql_creds", MSSQL),
            ("mysql_creds", MySQL),
            ("mariadb_creds", MariaDB),
        ]
    )
    def valid_dbs_instance_to_handle_large_doc(self, request: Any) -> Any:
        return self.get_db_instance(request=request)

    def get_db_instance(self, request: Any) -> UnstractDB:
        creds_name, db_class = request.param
        creds = getattr(self, creds_name)
        if not creds:
            pytest.fail(f"Unknown credentials: {creds_name}")
        db_instance = db_class(settings=creds)
        return db_instance

    # Gets all invalid-db instances for postgres, redshift:
    @pytest.fixture(
        params=[
            ("invalid_postgres_creds", PostgreSQL),
            ("invalid_redshift_creds", Redshift),
        ]
    )
    def invalid_dbs_instance(self, request: Any) -> Any:
        return self.get_db_instance(request=request)

    @pytest.fixture
    def valid_bigquery_db_instance(self) -> Any:
        return BigQuery(settings=self.bigquery_settings)

    # Gets all invalid-db instances for snowflake:
    @pytest.fixture(
        params=[
            ("invalid_snowflake_db", SnowflakeDB),
            ("invalid_snowflake_schema", SnowflakeDB),
            ("invalid_snowflake_warehouse", SnowflakeDB),
        ]
    )
    def invalid_snowflake_db_instance(self, request: Any) -> Any:
        return self.get_db_instance(request=request)
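The (creds_name, db_class) tuples above rely on a common pytest idiom: parametrize the fixture, then resolve the credentials attribute off the test instance with getattr. A self-contained toy version of the same pattern:

```python
import pytest


class FakeDB:
    def __init__(self, settings: dict) -> None:
        self.settings = settings


class TestFixtureLookupPattern:
    @pytest.fixture(autouse=True)
    def base_setup(self) -> None:
        # autouse: runs before any other fixture of this class resolves
        self.fake_creds = {"user": "u", "password": "p"}

    @pytest.fixture(params=[("fake_creds", FakeDB)])
    def db_instance(self, request) -> FakeDB:
        creds_name, db_class = request.param
        return db_class(settings=getattr(self, creds_name))

    def test_lookup(self, db_instance: FakeDB) -> None:
        assert db_instance.settings["user"] == "u"
```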
(One file's diff suppressed because its lines are too long — the tests' static/large_doc.txt fixture referenced below.)
backend/workflow_manager/endpoint/tests/ — new test file (TestCreateTableIfNotExists)
@@ -0,0 +1,97 @@
import pytest  # type: ignore
from workflow_manager.endpoint.database_utils import DatabaseUtils
from workflow_manager.endpoint.exceptions import UnstractDBException

from unstract.connectors.databases.unstract_db import UnstractDB

from .base_test_db import BaseTestDB


class TestCreateTableIfNotExists(BaseTestDB):
    def test_create_table_if_not_exists_valid(
        self, valid_dbs_instance: UnstractDB
    ) -> None:
        engine = valid_dbs_instance.get_engine()
        result = DatabaseUtils.create_table_if_not_exists(
            db_class=valid_dbs_instance,
            engine=engine,
            table_name=self.valid_table_name,
            database_entry=self.database_entry,
        )
        assert result is None

    def test_create_table_if_not_exists_bigquery_valid(
        self, valid_bigquery_db_instance: UnstractDB
    ) -> None:
        engine = valid_bigquery_db_instance.get_engine()
        result = DatabaseUtils.create_table_if_not_exists(
            db_class=valid_bigquery_db_instance,
            engine=engine,
            table_name=self.valid_bigquery_table_name,
            database_entry=self.database_entry,
        )
        assert result is None

    def test_create_table_if_not_exists_invalid_schema(
        self, invalid_dbs_instance: UnstractDB
    ) -> None:
        engine = invalid_dbs_instance.get_engine()
        with pytest.raises(UnstractDBException):
            DatabaseUtils.create_table_if_not_exists(
                db_class=invalid_dbs_instance,
                engine=engine,
                table_name=self.valid_table_name,
                database_entry=self.database_entry,
            )

    def test_create_table_if_not_exists_invalid_syntax(
        self, valid_dbs_instance: UnstractDB
    ) -> None:
        engine = valid_dbs_instance.get_engine()
        with pytest.raises(UnstractDBException):
            DatabaseUtils.create_table_if_not_exists(
                db_class=valid_dbs_instance,
                engine=engine,
                table_name=self.invalid_syntax_table_name,
                database_entry=self.database_entry,
            )

    def test_create_table_if_not_exists_wrong_table_name(
        self, valid_dbs_instance: UnstractDB
    ) -> None:
        engine = valid_dbs_instance.get_engine()
        with pytest.raises(UnstractDBException):
            DatabaseUtils.create_table_if_not_exists(
                db_class=valid_dbs_instance,
                engine=engine,
                table_name=self.invalid_wrong_table_name,
                database_entry=self.database_entry,
            )

    def test_create_table_if_not_exists_feature_not_supported(
        self, invalid_dbs_instance: UnstractDB
    ) -> None:
        engine = invalid_dbs_instance.get_engine()
        with pytest.raises(UnstractDBException):
            DatabaseUtils.create_table_if_not_exists(
                db_class=invalid_dbs_instance,
                engine=engine,
                table_name=self.invalid_wrong_table_name,
                database_entry=self.database_entry,
            )

    def test_create_table_if_not_exists_invalid_snowflake_db(
        self, invalid_snowflake_db_instance: UnstractDB
    ) -> None:
        engine = invalid_snowflake_db_instance.get_engine()
        with pytest.raises(UnstractDBException):
            DatabaseUtils.create_table_if_not_exists(
                db_class=invalid_snowflake_db_instance,
                engine=engine,
                table_name=self.invalid_wrong_table_name,
                database_entry=self.database_entry,
            )


if __name__ == "__main__":
    pytest.main()
backend/workflow_manager/endpoint/tests/ — new test file (TestExecuteWriteQuery)
@@ -0,0 +1,169 @@
import os
import uuid
from typing import Any

import pytest  # type: ignore
from workflow_manager.endpoint.database_utils import DatabaseUtils
from workflow_manager.endpoint.exceptions import UnstractDBException

from unstract.connectors.databases.redshift import Redshift
from unstract.connectors.databases.unstract_db import UnstractDB

from .base_test_db import BaseTestDB


class TestExecuteWriteQuery(BaseTestDB):
    @pytest.fixture(autouse=True)
    def setup(self, base_setup: Any) -> None:
        self.sql_columns_and_values = {
            "created_by": "Unstract/DBWriter",
            "created_at": "2024-05-20 10:36:25.362609",
            "data": '{"input_file": "simple.pdf", "result": "report"}',
            "id": str(uuid.uuid4()),
        }

    def test_execute_write_query_valid(self, valid_dbs_instance: Any) -> None:
        engine = valid_dbs_instance.get_engine()
        result = DatabaseUtils.execute_write_query(
            db_class=valid_dbs_instance,
            engine=engine,
            table_name=self.valid_table_name,
            sql_keys=list(self.sql_columns_and_values.keys()),
            sql_values=list(self.sql_columns_and_values.values()),
        )
        assert result is None

    def test_execute_write_query_invalid_schema(
        self, invalid_dbs_instance: Any
    ) -> None:
        engine = invalid_dbs_instance.get_engine()
        with pytest.raises(UnstractDBException):
            DatabaseUtils.execute_write_query(
                db_class=invalid_dbs_instance,
                engine=engine,
                table_name=self.valid_table_name,
                sql_keys=list(self.sql_columns_and_values.keys()),
                sql_values=list(self.sql_columns_and_values.values()),
            )

    def test_execute_write_query_invalid_syntax(self, valid_dbs_instance: Any) -> None:
        engine = valid_dbs_instance.get_engine()
        with pytest.raises(UnstractDBException):
            DatabaseUtils.execute_write_query(
                db_class=valid_dbs_instance,
                engine=engine,
                table_name=self.invalid_syntax_table_name,
                sql_keys=list(self.sql_columns_and_values.keys()),
                sql_values=list(self.sql_columns_and_values.values()),
            )

    def test_execute_write_query_feature_not_supported(
        self, invalid_dbs_instance: Any
    ) -> None:
        engine = invalid_dbs_instance.get_engine()
        with pytest.raises(UnstractDBException):
            DatabaseUtils.execute_write_query(
                db_class=invalid_dbs_instance,
                engine=engine,
                table_name=self.invalid_wrong_table_name,
                sql_keys=list(self.sql_columns_and_values.keys()),
                sql_values=list(self.sql_columns_and_values.values()),
            )

    def load_text_to_sql_values(self) -> dict[str, Any]:
        file_path = os.path.join(os.path.dirname(__file__), "static", "large_doc.txt")
        with open(file_path, encoding="utf-8") as file:
            content = file.read()
        sql_columns_and_values = self.sql_columns_and_values.copy()
        sql_columns_and_values["data"] = content
        return sql_columns_and_values

    @pytest.fixture
    def valid_redshift_db_instance(self) -> Any:
        return Redshift(self.valid_redshift_creds)

    def test_execute_write_query_datatype_too_large_redshift(
        self, valid_redshift_db_instance: Any
    ) -> None:
        engine = valid_redshift_db_instance.get_engine()
        sql_columns_and_values = self.load_text_to_sql_values()
        with pytest.raises(UnstractDBException):
            DatabaseUtils.execute_write_query(
                db_class=valid_redshift_db_instance,
                engine=engine,
                table_name=self.valid_table_name,
                sql_keys=list(sql_columns_and_values.keys()),
                sql_values=list(sql_columns_and_values.values()),
            )

    def test_execute_write_query_bigquery_valid(
        self, valid_bigquery_db_instance: Any
    ) -> None:
        engine = valid_bigquery_db_instance.get_engine()
        result = DatabaseUtils.execute_write_query(
            db_class=valid_bigquery_db_instance,
            engine=engine,
            table_name=self.valid_bigquery_table_name,
            sql_keys=list(self.sql_columns_and_values.keys()),
            sql_values=list(self.sql_columns_and_values.values()),
        )
        assert result is None

    def test_execute_write_query_wrong_table_name(
        self, valid_dbs_instance: UnstractDB
    ) -> None:
        engine = valid_dbs_instance.get_engine()
        with pytest.raises(UnstractDBException):
            DatabaseUtils.execute_write_query(
                db_class=valid_dbs_instance,
                engine=engine,
                table_name=self.invalid_wrong_table_name,
                sql_keys=list(self.sql_columns_and_values.keys()),
                sql_values=list(self.sql_columns_and_values.values()),
            )

    def test_execute_write_query_bigquery_large_doc(
        self, valid_bigquery_db_instance: Any
    ) -> None:
        engine = valid_bigquery_db_instance.get_engine()
        sql_columns_and_values = self.load_text_to_sql_values()
        result = DatabaseUtils.execute_write_query(
            db_class=valid_bigquery_db_instance,
            engine=engine,
            table_name=self.valid_bigquery_table_name,
            sql_keys=list(sql_columns_and_values.keys()),
            sql_values=list(sql_columns_and_values.values()),
        )
        assert result is None

    def test_execute_write_query_invalid_snowflake_db(
        self, invalid_snowflake_db_instance: UnstractDB
    ) -> None:
        engine = invalid_snowflake_db_instance.get_engine()
        with pytest.raises(UnstractDBException):
            DatabaseUtils.execute_write_query(
                db_class=invalid_snowflake_db_instance,
                engine=engine,
                table_name=self.invalid_wrong_table_name,
                sql_keys=list(self.sql_columns_and_values.keys()),
                sql_values=list(self.sql_columns_and_values.values()),
            )

    # Make this function at last to cover all large doc
    def test_execute_write_query_large_doc(
        self, valid_dbs_instance_to_handle_large_doc: Any
    ) -> None:
        engine = valid_dbs_instance_to_handle_large_doc.get_engine()
        sql_columns_and_values = self.load_text_to_sql_values()
        result = DatabaseUtils.execute_write_query(
            db_class=valid_dbs_instance_to_handle_large_doc,
            engine=engine,
            table_name=self.valid_table_name,
            sql_keys=list(sql_columns_and_values.keys()),
            sql_values=list(sql_columns_and_values.values()),
        )
        assert result is None


if __name__ == "__main__":
    pytest.main()
unstract.connectors.databases.bigquery
@@ -1,14 +1,22 @@
 import datetime
 import json
+import logging
 import os
 from typing import Any

+import google.api_core.exceptions
 from google.cloud import bigquery
 from google.cloud.bigquery import Client

+from unstract.connectors.databases.exceptions import (
+    BigQueryForbiddenException,
+    BigQueryNotFoundException,
+)
 from unstract.connectors.databases.unstract_db import UnstractDB
 from unstract.connectors.exceptions import ConnectorError

+logger = logging.getLogger(__name__)
+

 class BigQuery(UnstractDB):
     def __init__(self, settings: dict[str, Any]):
@@ -87,3 +95,24 @@ class BigQuery(UnstractDB):
             f"created_by string, created_at TIMESTAMP, "
         )
         return sql_query
+
+    def execute_query(
+        self, engine: Any, sql_query: str, sql_values: Any, **kwargs: Any
+    ) -> None:
+        table_name = str(kwargs.get("table_name"))
+        try:
+            if sql_values:
+                engine.query(sql_query, job_config=sql_values)
+            else:
+                engine.query(sql_query)
+        except google.api_core.exceptions.Forbidden as e:
+            logger.error(f"Forbidden exception in creating/inserting data: {str(e)}")
+            raise BigQueryForbiddenException(
+                detail=e.message,
+                table_name=table_name,
+            ) from e
+        except google.api_core.exceptions.NotFound as e:
+            logger.error(f"Resource not found in creating/inserting table: {str(e)}")
+            raise BigQueryNotFoundException(
+                detail=e.message, table_name=table_name
+            ) from e
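The job_config path above is BigQuery's standard named-parameter mechanism. A hedged, standalone sketch of the same call shape (the project/dataset/table names are placeholders, and running it needs real credentials):

```python
from google.cloud import bigquery

client = bigquery.Client()  # assumes application-default or JSON credentials
sql = "INSERT INTO my_project.my_dataset.my_table (id, data) VALUES (@id, @data)"
job_config = bigquery.QueryJobConfig(
    query_parameters=[
        bigquery.ScalarQueryParameter("id", "STRING", "row-1"),
        bigquery.ScalarQueryParameter("data", "STRING", '{"k": "v"}'),
    ]
)
client.query(sql, job_config=job_config).result()  # wait for the DML job to finish
```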
unstract.connectors.databases.exceptions (new file)
@@ -0,0 +1,92 @@
from typing import Any

from unstract.connectors.exceptions import ConnectorBaseException


class UnstractDBConnectorException(ConnectorBaseException):
    """Base class for database-related exceptions from Unstract connectors."""

    def __init__(
        self,
        detail: Any,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        default_detail = "Error creating/inserting to database. "
        user_message = default_detail if not detail else detail
        super().__init__(*args, user_message=user_message, **kwargs)
        self.detail = user_message


class InvalidSyntaxException(UnstractDBConnectorException):
    def __init__(self, detail: Any, database: Any) -> None:
        default_detail = (
            f"Error creating/writing to {database}. Syntax incorrect. "
            f"Please check your table-name or schema. "
        )
        super().__init__(detail=default_detail + detail)


class InvalidSchemaException(UnstractDBConnectorException):
    def __init__(self, detail: Any, database: str) -> None:
        default_detail = f"Error creating/writing to {database}. Schema not valid. "
        super().__init__(detail=default_detail + detail)


class UnderfinedTableException(UnstractDBConnectorException):
    def __init__(self, detail: Any, database: str) -> None:
        default_detail = (
            f"Error creating/writing to {database}. Undefined table. "
            f"Please check your table-name or schema. "
        )
        super().__init__(detail=default_detail + detail)


class ValueTooLongException(UnstractDBConnectorException):
    def __init__(self, detail: Any, database: str) -> None:
        default_detail = (
            f"Error creating/writing to {database}. "
            f"Size of the inserted data exceeds the limit provided by the database. "
        )
        super().__init__(detail=default_detail + detail)


class FeatureNotSupportedException(UnstractDBConnectorException):
    def __init__(self, detail: Any, database: str) -> None:
        default_detail = (
            f"Error creating/writing to {database}. "
            f"Feature not supported sql error. "
        )
        super().__init__(detail=default_detail + detail)


class SnowflakeProgrammingException(UnstractDBConnectorException):
    def __init__(self, detail: Any, database: str) -> None:
        default_detail = (
            f"Error creating/writing to {database}. "
            f"Please check your snowflake credentials. "
        )
        super().__init__(default_detail + detail)


class BigQueryForbiddenException(UnstractDBConnectorException):
    def __init__(self, detail: Any, table_name: str) -> None:
        default_detail = (
            f"Error creating/writing to {table_name}. "
            f"Access forbidden in bigquery. Please check your permissions. "
        )
        super().__init__(detail=default_detail + detail)


class BigQueryNotFoundException(UnstractDBConnectorException):
    def __init__(self, detail: str, table_name: str) -> None:
        default_detail = (
            f"Error creating/writing to {table_name}. "
            f"The requested resource was not found. "
        )
        super().__init__(detail=default_detail + detail)
unstract.connectors.databases.exceptions_helper (new file)
@@ -0,0 +1,20 @@
from typing import Any


class ExceptionHelper:
    @staticmethod
    def extract_byte_exception(e: Exception) -> Any:
        """_summary_
        Extract error details from byte_exception.
        Used by mssql
        Args:
            e (Exception): _description_

        Returns:
            Any: _description_
        """
        error_message = str(e)
        error_code, error_details = eval(error_message)
        if isinstance(error_details, bytes):
            error_details = error_details.decode("utf-8")
        return error_details
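pymssql and pymysql raise errors whose args stringify to a (code, bytes-or-str message) tuple, which is what the eval-based helper unpacks. A toy reproduction of the parsing step with a fake exception (note: eval here trusts the driver's repr; that is the approach this PR takes, not a general recommendation):

```python
class FakeDriverError(Exception):
    """Stand-in for pymssql/pymysql errors whose single arg is a tuple."""


e = FakeDriverError((102, b"Incorrect syntax near 'test_output'."))
error_message = str(e)  # "(102, b\"Incorrect syntax near 'test_output'.\")"
error_code, error_details = eval(error_message)
if isinstance(error_details, bytes):
    error_details = error_details.decode("utf-8")
assert error_code == 102
assert "Incorrect syntax" in error_details
```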
unstract.connectors.databases.mariadb
@@ -5,10 +5,11 @@ import pymysql
 from pymysql.connections import Connection

 from unstract.connectors.databases.mysql import MySQL
+from unstract.connectors.databases.mysql_handler import MysqlHandler
 from unstract.connectors.databases.unstract_db import UnstractDB


-class MariaDB(UnstractDB):
+class MariaDB(UnstractDB, MysqlHandler):
     def __init__(self, settings: dict[str, Any]):
         super().__init__("MariaDB")

@@ -74,3 +75,13 @@ class MariaDB(UnstractDB):
         Mysql and Mariadb share same SQL column type
         """
         return str(MySQL.sql_to_db_mapping(value=value))
+
+    def execute_query(
+        self, engine: Any, sql_query: str, sql_values: Any, **kwargs: Any
+    ) -> None:
+        MysqlHandler.execute_query(
+            engine=engine,
+            sql_query=sql_query,
+            sql_values=sql_values,
+            database=self.database,
+        )
unstract.connectors.databases.mssql
@@ -1,11 +1,17 @@
+import logging
 import os
 from typing import Any

 import pymssql
-from pymssql import Connection
+import pymssql._pymssql as PyMssql
+from pymssql import Connection  # type: ignore

+from unstract.connectors.databases.exceptions import InvalidSyntaxException
+from unstract.connectors.databases.exceptions_helper import ExceptionHelper
 from unstract.connectors.databases.unstract_db import UnstractDB

+logger = logging.getLogger(__name__)
+

 class MSSQL(UnstractDB):
     def __init__(self, settings: dict[str, Any]):
@@ -49,16 +55,9 @@ class MSSQL(UnstractDB):
         return True

     def get_engine(self) -> Connection:
-        if self.port:
-            return pymssql.connect(
-                server=self.server,
-                port=self.port,
-                user=self.user,
-                password=self.password,
-                database=self.database,
-            )
-        return pymssql.connect(
+        return pymssql.connect(  # type: ignore
             server=self.server,
             port=self.port,
             user=self.user,
             password=self.password,
             database=self.database,
@@ -74,3 +73,22 @@ class MSSQL(UnstractDB):
             f"created_by TEXT, created_at DATETIMEOFFSET, "
         )
         return sql_query
+
+    def execute_query(
+        self, engine: Any, sql_query: str, sql_values: Any, **kwargs: Any
+    ) -> None:
+        try:
+            with engine.cursor() as cursor:
+                if sql_values:
+                    cursor.execute(sql_query, sql_values)
+                else:
+                    cursor.execute(sql_query)
+            engine.commit()
+        except (PyMssql.ProgrammingError, PyMssql.OperationalError) as e:
+            error_details = ExceptionHelper.extract_byte_exception(e=e)
+            logger.error(
+                f"Invalid syntax in creating/inserting mssql data: {error_details}"
+            )
+            raise InvalidSyntaxException(
+                detail=error_details, database=self.database
+            ) from e
unstract.connectors.databases.mysql
@@ -5,10 +5,11 @@ from typing import Any
 import pymysql
 from pymysql.connections import Connection

+from unstract.connectors.databases.mysql_handler import MysqlHandler
 from unstract.connectors.databases.unstract_db import UnstractDB


-class MySQL(UnstractDB):
+class MySQL(UnstractDB, MysqlHandler):
     def __init__(self, settings: dict[str, Any]):
         super().__init__("MySQL")

@@ -78,3 +79,13 @@ class MySQL(UnstractDB):
             datetime.datetime: "TIMESTAMP",
         }
         return mapping.get(python_type, "LONGTEXT")
+
+    def execute_query(
+        self, engine: Any, sql_query: str, sql_values: Any, **kwargs: Any
+    ) -> None:
+        MysqlHandler.execute_query(
+            engine=engine,
+            sql_query=sql_query,
+            sql_values=sql_values,
+            database=self.database,
+        )
unstract.connectors.databases.mysql_handler (new file)
@@ -0,0 +1,29 @@
import logging
from typing import Any

import pymysql.err as MysqlError

from unstract.connectors.databases.exceptions import InvalidSyntaxException
from unstract.connectors.databases.exceptions_helper import ExceptionHelper

logger = logging.getLogger(__name__)


class MysqlHandler:
    @staticmethod
    def execute_query(
        engine: Any, sql_query: str, sql_values: Any, database: Any
    ) -> None:
        try:
            with engine.cursor() as cursor:
                if sql_values:
                    cursor.execute(sql_query, sql_values)
                else:
                    cursor.execute(sql_query)
            engine.commit()
        except MysqlError.ProgrammingError as e:
            error_details = ExceptionHelper.extract_byte_exception(e=e)
            logger.error(
                f"Invalid syntax in creating/inserting mysql data: {error_details}"
            )
            raise InvalidSyntaxException(detail=error_details, database=database) from e
unstract.connectors.databases.postgresql
@@ -4,10 +4,11 @@ from typing import Any
 import psycopg2
 from psycopg2.extensions import connection

+from unstract.connectors.databases.psycopg_handler import PsycoPgHandler
 from unstract.connectors.databases.unstract_db import UnstractDB


-class PostgreSQL(UnstractDB):
+class PostgreSQL(UnstractDB, PsycoPgHandler):
     def __init__(self, settings: dict[str, Any]):
         super().__init__("PostgreSQL")

@@ -71,3 +72,13 @@ class PostgreSQL(UnstractDB):
             options=f"-c search_path={self.schema}",
         )
         return con
+
+    def execute_query(
+        self, engine: Any, sql_query: str, sql_values: Any, **kwargs: Any
+    ) -> None:
+        PsycoPgHandler.execute_query(
+            engine=engine,
+            sql_query=sql_query,
+            sql_values=sql_values,
+            database=self.database,
+        )
unstract.connectors.databases.psycopg_handler (new file)
@@ -0,0 +1,50 @@
import logging
from typing import Any

from psycopg2 import errors as PsycopgError

from unstract.connectors.databases.exceptions import (
    FeatureNotSupportedException,
    InvalidSchemaException,
    InvalidSyntaxException,
    UnderfinedTableException,
    ValueTooLongException,
)

logger = logging.getLogger(__name__)


class PsycoPgHandler:
    @staticmethod
    def execute_query(
        engine: Any, sql_query: str, sql_values: Any, database: Any
    ) -> None:
        try:
            with engine.cursor() as cursor:
                if sql_values:
                    cursor.execute(sql_query, sql_values)
                else:
                    cursor.execute(sql_query)
            engine.commit()
        except PsycopgError.InvalidSchemaName as e:
            logger.error(f"Invalid schema in creating table: {e.pgerror}")
            raise InvalidSchemaException(detail=e.pgerror, database=database) from e
        except PsycopgError.UndefinedTable as e:
            logger.error(f"Undefined table in inserting: {e.pgerror}")
            raise UnderfinedTableException(detail=e.pgerror, database=database) from e
        except PsycopgError.SyntaxError as e:
            logger.error(f"Invalid syntax in creating/inserting data: {e.pgerror}")
            raise InvalidSyntaxException(detail=e.pgerror, database=database) from e
        except PsycopgError.FeatureNotSupported as e:
            logger.error(
                f"feature not supported in creating/inserting data: {e.pgerror}"
            )
            raise FeatureNotSupportedException(
                detail=e.pgerror, database=database
            ) from e
        except (
            PsycopgError.StringDataRightTruncation,
            PsycopgError.InternalError_,
        ) as e:
            logger.error(f"value too long for datatype: {e.pgerror}")
            raise ValueTooLongException(detail=e.pgerror, database=database) from e
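psycopg2 exposes SQLSTATE-specific exception classes under psycopg2.errors, which is what lets one handler serve both PostgreSQL and Redshift while mapping schema, table, syntax, and size failures to distinct connector exceptions. A hedged sketch of exercising the mapping against a local Postgres (host and credentials are placeholders):

```python
import psycopg2

from unstract.connectors.databases.exceptions import UnderfinedTableException
from unstract.connectors.databases.psycopg_handler import PsycoPgHandler

engine = psycopg2.connect(
    host="localhost",
    user="unstract_dev",
    password="unstract_pass",
    database="unstract_db",
)
try:
    # Inserting into a table that does not exist raises
    # psycopg2.errors.UndefinedTable inside the handler.
    PsycoPgHandler.execute_query(
        engine=engine,
        sql_query="INSERT INTO no_such_table (id) VALUES (%s)",
        sql_values=("1",),
        database="unstract_db",
    )
except UnderfinedTableException as e:
    print(e.detail)  # "Error creating/writing to unstract_db. Undefined table. ..."
```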
unstract.connectors.databases.redshift
@@ -5,10 +5,11 @@ from typing import Any
 import psycopg2
 from psycopg2.extensions import connection

+from unstract.connectors.databases.psycopg_handler import PsycoPgHandler
 from unstract.connectors.databases.unstract_db import UnstractDB


-class Redshift(UnstractDB):
+class Redshift(UnstractDB, PsycoPgHandler):
     def __init__(self, settings: dict[str, Any]):
         super().__init__("Redshift")

@@ -74,14 +75,13 @@ class Redshift(UnstractDB):
             str: _description_
         """
         python_type = type(value)
-
         mapping = {
-            str: "VARCHAR(65535)",
+            str: "SUPER",
             int: "BIGINT",
             float: "DOUBLE PRECISION",
             datetime.datetime: "TIMESTAMP",
         }
-        return mapping.get(python_type, "VARCHAR(65535)")
+        return mapping.get(python_type, "SUPER")

     @staticmethod
     def get_create_table_query(table: str) -> str:
@@ -91,3 +91,13 @@ class Redshift(UnstractDB):
             f"created_by VARCHAR(65535), created_at TIMESTAMP, "
         )
         return sql_query
+
+    def execute_query(
+        self, engine: Any, sql_query: str, sql_values: Any, **kwargs: Any
+    ) -> None:
+        PsycoPgHandler.execute_query(
+            engine=engine,
+            sql_query=sql_query,
+            sql_values=sql_values,
+            database=self.database,
+        )
unstract.connectors.databases.snowflake
@@ -1,11 +1,16 @@
+import logging
 import os
 from typing import Any

 import snowflake.connector
+import snowflake.connector.errors as SnowflakeError
 from snowflake.connector.connection import SnowflakeConnection

+from unstract.connectors.databases.exceptions import SnowflakeProgrammingException
 from unstract.connectors.databases.unstract_db import UnstractDB

+logger = logging.getLogger(__name__)
+

 class SnowflakeDB(UnstractDB):
     def __init__(self, settings: dict[str, Any]):
@@ -70,3 +75,22 @@ class SnowflakeDB(UnstractDB):
             f"created_by TEXT, created_at TIMESTAMP, "
         )
         return sql_query
+
+    def execute_query(
+        self, engine: Any, sql_query: str, sql_values: Any, **kwargs: Any
+    ) -> None:
+        try:
+            with engine.cursor() as cursor:
+                if sql_values:
+                    cursor.execute(sql_query, sql_values)
+                else:
+                    cursor.execute(sql_query)
+            engine.commit()
+        except SnowflakeError.ProgrammingError as e:
+            logger.error(
+                f"snowflake programming error in crearing/inserting table: "
+                f"{e.msg} {e.errno}"
+            )
+            raise SnowflakeProgrammingException(
+                detail=e.msg, database=self.database
+            ) from e
unstract.connectors.databases.unstract_db
@@ -68,7 +68,7 @@ class UnstractDB(UnstractConnector, ABC):
         try:
             self.get_engine()
         except Exception as e:
-            raise ConnectorError(str(e))
+            raise ConnectorError(str(e)) from e
         return True

     def execute(self, query: str) -> Any:
@@ -77,7 +77,7 @@ class UnstractDB(UnstractConnector, ABC):
                 cursor.execute(query)
                 return cursor.fetchall()
         except Exception as e:
-            raise ConnectorError(str(e))
+            raise ConnectorError(str(e)) from e

     @staticmethod
     def sql_to_db_mapping(value: str) -> str:
@@ -107,3 +107,9 @@ class UnstractDB(UnstractConnector, ABC):
             f"created_by TEXT, created_at TIMESTAMP, "
         )
         return sql_query
+
+    @abstractmethod
+    def execute_query(
+        self, engine: Any, sql_query: str, sql_values: Any, **kwargs: Any
+    ) -> None:
+        pass
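With execute_query now part of the UnstractDB abstract interface, every connector owns both execution and driver-specific error translation. A minimal hypothetical connector showing the contract (SQLite stands in for a real warehouse; the class and its error mapping are illustrative only, not part of this PR):

```python
import sqlite3
from typing import Any


class SQLiteDB:  # hypothetical; a real connector would subclass UnstractDB
    def get_engine(self) -> sqlite3.Connection:
        return sqlite3.connect(":memory:")

    def execute_query(
        self, engine: Any, sql_query: str, sql_values: Any, **kwargs: Any
    ) -> None:
        # Parameterized execution plus driver-specific error translation,
        # mirroring the shape of the connectors in this PR.
        try:
            if sql_values:
                engine.execute(sql_query, sql_values)
            else:
                engine.execute(sql_query)
            engine.commit()
        except sqlite3.OperationalError as e:
            raise RuntimeError(f"Error creating/inserting to database: {e}") from e
```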