Fix/write-db-connector-error-handling (#330)

* fix connector insert

* started adding error handling

* error handling for postgres/snowflake

* test cases for postgres/snowflake

* test cases for postgres/snowflake/redshift

* small change

* test cases covered for postgres, redshift

* test cases covered for bigquery

* test cases covered for postgres, snowflake, redshift, bigquery

* test cases covered for postgres, snowflake, redshift, bigquery, mssql

* test cases covered for postgres, snowflake, redshift, bigquery, mssql, mariadb, mysql

* adding lock file

* adding celery change

* test connector changes

* removing test connection

* removing KeyNotFound exception

* removing KeyNotFound exception

* refactoring db_utils

* sample.test.env changes

* fixing celery part on test cases

* fixing celery part on test cases

* refactoring db_utils

* refactor connectors

* refactor connectors

* refactor big query connectors

* small change

* small change in db_utils

* removing print in database_utils

* PR review

* small change

* small change

* PR comment change

---------

Co-authored-by: Ritwik G <100672805+ritwik-g@users.noreply.github.com>
Kirtiman Mishra
2024-06-10 10:21:58 +05:30
committed by GitHub
parent dd79450d93
commit 2b6cbe4d4b
27 changed files with 969 additions and 122 deletions

1
.gitignore vendored
View File

@@ -132,6 +132,7 @@ celerybeat.pid
*.sage.py
# Environments
test*.env
.env
.env.export
.venv*

16
backend/pdm.lock generated
View File

@@ -3453,6 +3453,20 @@ files = [
{file = "pytest-8.2.2.tar.gz", hash = "sha256:de4bb8104e201939ccdc688b27a89a7be2079b22e2bd2b07f806b6ba71117977"},
]
[[package]]
name = "pytest-dotenv"
version = "0.5.2"
summary = "A py.test plugin that parses environment files before running tests"
groups = ["test"]
dependencies = [
"pytest>=5.0.0",
"python-dotenv>=0.9.1",
]
files = [
{file = "pytest-dotenv-0.5.2.tar.gz", hash = "sha256:2dc6c3ac6d8764c71c6d2804e902d0ff810fa19692e95fe138aefc9b1aa73732"},
{file = "pytest_dotenv-0.5.2-py3-none-any.whl", hash = "sha256:40a2cece120a213898afaa5407673f6bd924b1fa7eafce6bda0e8abffe2f710f"},
]
[[package]]
name = "python-crontab"
version = "3.1.0"
@@ -3484,7 +3498,7 @@ name = "python-dotenv"
version = "1.0.0"
requires_python = ">=3.8"
summary = "Read key-value pairs from a .env file and set them as environment variables"
groups = ["default"]
groups = ["default", "test"]
files = [
{file = "python-dotenv-1.0.0.tar.gz", hash = "sha256:a8df96034aae6d2d50a4ebe8216326c61c3eb64836776504fcca410e5937a3ba"},
{file = "python_dotenv-1.0.0-py3-none-any.whl", hash = "sha256:f5971a9226b701070a4bf2c38c89e5a3f0d64de8debda981d1db98583009122a"},

View File

@@ -66,8 +66,13 @@ deploy = [
]
test = [
"pytest>=8.0.1",
"pytest-dotenv==0.5.2",
]
[tool.pytest.ini_options]
env_files = "test.env" # Load env from particular env file
addopts = "-s"
[tool.pdm.scripts]
# Commands for backend
backend.cmd = "./entrypoint.sh"

40
backend/sample.test.env Normal file
View File

@@ -0,0 +1,40 @@
DB_HOST='unstract-db'
DB_USER='unstract_dev'
DB_PASSWORD='unstract_pass'
DB_NAME='unstract_db'
DB_PORT=5432
REDSHIFT_HOST=
REDSHIFT_PORT="5439"
REDSHIFT_DB=
REDSHIFT_USER=
REDSHIFT_PASSWORD=
BIGQUERY_CREDS=
SNOWFLAKE_USER=
SNOWFLAKE_PASSWORD=
SNOWFLAKE_ACCOUNT=
SNOWFLAKE_ROLE=
SNOWFLAKE_DB=
SNOWFLAKE_SCHEMA=
SNOWFLAKE_WAREHOUSE=
MSSQL_SERVER=
MSSQL_PORT=
MSSQL_PASSWORD=
MSSQL_DB=
MSSQL_USER=
MYSQL_SERVER=
MYSQL_PORT=
MYSQL_PASSWORD=
MYSQL_DB=
MYSQL_USER=
MARIADB_SERVER=
MARIADB_PORT=
MARIADB_PASSWORD=
MARIADB_DB=
MARIADB_USER=

View File

@@ -5,8 +5,9 @@ class TableColumns:
class DBConnectionClass:
SNOWFLAKE = "SnowflakeConnection"
BIGQUERY = "Client"
SNOWFLAKE = "SnowflakeDB"
BIGQUERY = "BigQuery"
MSSQL = "MSSQL"
class Snowflake:

View File

@@ -8,22 +8,27 @@ from utils.constants import Common
from workflow_manager.endpoint.constants import (
BigQuery,
DBConnectionClass,
Snowflake,
TableColumns,
)
from workflow_manager.endpoint.exceptions import BigQueryTableNotFound
from workflow_manager.endpoint.db_connector_helper import DBConnectorQueryHelper
from workflow_manager.endpoint.exceptions import (
BigQueryTableNotFound,
UnstractDBException,
)
from workflow_manager.workflow.enums import AgentName, ColumnModes
from unstract.connectors.databases import connectors as db_connectors
from unstract.connectors.databases.exceptions import UnstractDBConnectorException
from unstract.connectors.databases.unstract_db import UnstractDB
from unstract.connectors.exceptions import ConnectorError
logger = logging.getLogger(__name__)
class DatabaseUtils:
@staticmethod
def make_sql_values_for_query(
values: dict[str, Any], column_types: dict[str, str], cls: Any = None
def get_sql_values_for_query(
values: dict[str, Any], column_types: dict[str, str], cls_name: str
) -> dict[str, str]:
"""Making Sql Columns and Values for Query.
@@ -51,32 +56,29 @@ class DatabaseUtils:
insert into the table accordingly
"""
sql_values: dict[str, Any] = {}
for column in values:
if cls == DBConnectionClass.SNOWFLAKE:
if cls_name == DBConnectionClass.SNOWFLAKE:
col = column.lower()
type_x = column_types[col]
if type_x in Snowflake.COLUMN_TYPES:
sql_values[column] = f"'{values[column]}'"
elif type_x == "VARIANT":
if type_x == "VARIANT":
values[column] = values[column].replace("'", "\\'")
sql_values[column] = f"parse_json($${values[column]}$$)"
else:
sql_values[column] = f"{values[column]}"
elif cls == DBConnectionClass.BIGQUERY:
elif cls_name == DBConnectionClass.BIGQUERY:
col = column.lower()
type_x = column_types[col]
if type_x in BigQuery.COLUMN_TYPES:
sql_values[column] = f"{type_x}('{values[column]}')"
sql_values[column] = f"{type_x}({values[column]})"
else:
sql_values[column] = f"'{values[column]}'"
sql_values[column] = f"{values[column]}"
else:
# Default to Other SQL DBs
# TODO: Handle numeric types with no quotes
sql_values[column] = f"'{values[column]}'"
sql_values[column] = f"{values[column]}"
if column_types.get("id"):
uuid_id = str(uuid.uuid4())
sql_values["id"] = f"'{uuid_id}'"
sql_values["id"] = f"{uuid_id}"
return sql_values
@staticmethod
@@ -96,7 +98,7 @@ class DatabaseUtils:
@staticmethod
def get_column_types(
cls: Any,
cls_name: Any,
table_name: str,
connector_id: str,
connector_settings: dict[str, Any],
@@ -118,7 +120,7 @@ class DatabaseUtils:
"""
column_types: dict[str, str] = {}
try:
if cls == DBConnectionClass.SNOWFLAKE:
if cls_name == DBConnectionClass.SNOWFLAKE:
query = f"describe table {table_name}"
results = DatabaseUtils.execute_and_fetch_data(
connector_id=connector_id,
@@ -127,7 +129,7 @@ class DatabaseUtils:
)
for column in results:
column_types[column[0].lower()] = column[1].split("(")[0]
elif cls == DBConnectionClass.BIGQUERY:
elif cls_name == DBConnectionClass.BIGQUERY:
bigquery_table_name = str.lower(table_name).split(".")
if len(bigquery_table_name) != BigQuery.TABLE_NAME_SIZE:
raise BigQueryTableNotFound()
@@ -223,8 +225,8 @@ class DatabaseUtils:
return values
@staticmethod
def get_sql_columns_and_values_for_query(
engine: Any,
def get_sql_query_data(
cls_name: str,
connector_id: str,
connector_settings: dict[str, Any],
table_name: str,
@@ -234,7 +236,6 @@ class DatabaseUtils:
provided values and table schema.
Args:
engine (Any): The database engine.
connector_id: The connector id of the connector provided
connector_settings: Connector settings provided by user
table_name (str): The name of the target table for the insert query.
@@ -252,26 +253,26 @@ class DatabaseUtils:
- For other SQL databases, it uses default SQL generation
based on column types.
"""
class_name = engine.__class__.__name__
column_types: dict[str, str] = DatabaseUtils.get_column_types(
cls=class_name,
cls_name=cls_name,
table_name=table_name,
connector_id=connector_id,
connector_settings=connector_settings,
)
sql_columns_and_values = DatabaseUtils.make_sql_values_for_query(
sql_columns_and_values = DatabaseUtils.get_sql_values_for_query(
values=values,
column_types=column_types,
cls=class_name,
cls_name=cls_name,
)
return sql_columns_and_values
@staticmethod
def execute_write_query(
db_class: UnstractDB,
engine: Any,
table_name: str,
sql_keys: list[str],
sql_values: list[str],
sql_values: Any,
) -> None:
"""Execute Insert Query.
@@ -284,22 +285,30 @@ class DatabaseUtils:
- Snowflake does not support INSERT INTO ... VALUES ...
syntax when VARIANT columns are present (JSON).
So we need to use INSERT INTO ... SELECT ... syntax
- sql values can contain data with single quotes, so they are passed as bound query parameters rather than inlined
"""
sql = (
f"INSERT INTO {table_name} ({','.join(sql_keys)}) "
f"SELECT {','.join(sql_values)}"
cls_name = db_class.__class__.__name__
sql = DBConnectorQueryHelper.build_sql_insert_query(
cls_name=cls_name, table_name=table_name, sql_keys=sql_keys
)
logger.debug(f"insertng into table with: {sql} query")
logger.debug(f"inserting into table {table_name} with: {sql} query")
sql_values = DBConnectorQueryHelper.prepare_sql_values(
cls_name=cls_name, sql_values=sql_values, sql_keys=sql_keys
)
logger.debug(f"sql_values: {sql_values}")
try:
if hasattr(engine, "cursor"):
with engine.cursor() as cursor:
cursor.execute(sql)
engine.commit()
else:
engine.query(sql)
except Exception as e:
logger.error(f"Error while writing data: {str(e)}")
raise e
db_class.execute_query(
engine=engine,
sql_query=sql,
sql_values=sql_values,
table_name=table_name,
sql_keys=sql_keys,
)
except UnstractDBConnectorException as e:
raise UnstractDBException(detail=e.detail) from e
logger.debug(f"sucessfully inserted into table {table_name} with: {sql} query")
@staticmethod
def get_db_class(
@@ -315,7 +324,10 @@ class DatabaseUtils:
) -> Any:
connector = db_connectors[connector_id][Common.METADATA][Common.CONNECTOR]
connector_class: UnstractDB = connector(connector_settings)
return connector_class.execute(query=query)
try:
return connector_class.execute(query=query)
except ConnectorError as e:
raise UnstractDBException(detail=e.message) from e
@staticmethod
def create_table_if_not_exists(
@@ -328,7 +340,6 @@ class DatabaseUtils:
Args:
class_name (UnstractDB): Type of Unstract DB connector
engine (Any): _description_
table_name (str): _description_
database_entry (dict[str, Any]): _description_
@@ -338,62 +349,9 @@ class DatabaseUtils:
sql = DBConnectorQueryHelper.create_table_query(
conn_cls=db_class, table=table_name, database_entry=database_entry
)
logger.debug(f"creating table with: {sql} query")
logger.debug(f"creating table {table_name} with: {sql} query")
try:
if hasattr(engine, "cursor"):
with engine.cursor() as cursor:
cursor.execute(sql)
engine.commit()
else:
engine.query(sql)
except Exception as e:
logger.error(f"Error while creating table: {str(e)}")
raise e
class DBConnectorQueryHelper:
"""A class that helps to generate query for connector table operations."""
@staticmethod
def create_table_query(
conn_cls: UnstractDB, table: str, database_entry: dict[str, Any]
) -> Any:
sql_query = ""
"""Generate a SQL query to create a table, based on the provided
database entry.
Args:
conn_cls (str): The database connector class.
Should be one of 'BIGQUERY', 'SNOWFLAKE', or other.
table (str): The name of the table to be created.
database_entry (dict[str, Any]):
A dictionary containing column names as keys
and their corresponding values.
These values are used to determine the data types,
for the columns in the table.
Returns:
str: A SQL query string to create a table with the specified name,
and column definitions.
Note:
- Each conn_cls have it's implementation for SQL create table query
Based on the implementation, a base SQL create table query will be
created containing Permanent columns
- Each conn_cls also has a mapping to convert python datatype to
corresponding column type (string, VARCHAR etc)
- keys in database_entry will be converted to column type, and
column values will be the valus in database_entry
- base SQL create table will be appended based column type and
values, and generates a complete SQL create table query
"""
create_table_query = conn_cls.get_create_table_query(table=table)
sql_query += create_table_query
for key, val in database_entry.items():
if key not in TableColumns.PERMANENT_COLUMNS:
sql_type = conn_cls.sql_to_db_mapping(val)
sql_query += f"{key} {sql_type}, "
return sql_query.rstrip(", ") + ");"
db_class.execute_query(engine=engine, sql_query=sql, sql_values=None)
except UnstractDBConnectorException as e:
raise UnstractDBException(detail=e.detail) from e
logger.debug(f"successfully created table {table_name} with: {sql} query")

View File

@@ -0,0 +1,77 @@
from typing import Any
from google.cloud import bigquery
from workflow_manager.endpoint.constants import DBConnectionClass, TableColumns
from unstract.connectors.databases.unstract_db import UnstractDB
class DBConnectorQueryHelper:
"""A class that helps to generate query for connector table operations."""
@staticmethod
def create_table_query(
conn_cls: UnstractDB, table: str, database_entry: dict[str, Any]
) -> Any:
sql_query = ""
"""Generate a SQL query to create a table, based on the provided
database entry.
Args:
conn_cls (str): The database connector class.
Should be one of 'BIGQUERY', 'SNOWFLAKE', or other.
table (str): The name of the table to be created.
database_entry (dict[str, Any]):
A dictionary containing column names as keys
and their corresponding values.
These values are used to determine the data types,
for the columns in the table.
Returns:
str: A SQL query string to create a table with the specified name,
and column definitions.
Note:
- Each conn_cls has its own implementation of the SQL create table query.
Based on the implementation, a base SQL create table query will be
created containing Permanent columns
- Each conn_cls also has a mapping to convert python datatype to
corresponding column type (string, VARCHAR etc)
- keys in database_entry will be converted to column type, and
column values will be the values in database_entry
- base SQL create table will be appended based column type and
values, and generates a complete SQL create table query
"""
create_table_query = conn_cls.get_create_table_query(table=table)
sql_query += create_table_query
for key, val in database_entry.items():
if key not in TableColumns.PERMANENT_COLUMNS:
sql_type = conn_cls.sql_to_db_mapping(val)
sql_query += f"{key} {sql_type}, "
return sql_query.rstrip(", ") + ");"
@staticmethod
def build_sql_insert_query(
cls_name: str, table_name: str, sql_keys: list[str]
) -> str:
keys_str = ",".join(sql_keys)
if cls_name == DBConnectionClass.BIGQUERY:
values_placeholder = ",".join(["@" + key for key in sql_keys])
else:
values_placeholder = ",".join(["%s" for _ in sql_keys])
return f"INSERT INTO {table_name} ({keys_str}) VALUES ({values_placeholder})"
@staticmethod
def prepare_sql_values(cls_name: str, sql_values: Any, sql_keys: list[str]) -> Any:
if cls_name == DBConnectionClass.MSSQL:
return tuple(sql_values)
elif cls_name == DBConnectionClass.BIGQUERY:
query_parameters = [
bigquery.ScalarQueryParameter(key, "STRING", value)
for key, value in zip(sql_keys, sql_values)
]
return bigquery.QueryJobConfig(query_parameters=query_parameters)
return sql_values
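A hedged usage sketch of the two helpers above, showing how the generated insert differs per connector; the table name and values are illustrative, and cls_name matches the connector class names listed in DBConnectionClass:

    keys = ["id", "data"]
    vals = ["some-uuid", '{"result": "report"}']

    # Default path (e.g. PostgreSQL/MySQL): positional placeholders, values passed through.
    sql = DBConnectorQueryHelper.build_sql_insert_query(
        cls_name="PostgreSQL", table_name="test_output", sql_keys=keys
    )
    # sql == "INSERT INTO test_output (id,data) VALUES (%s,%s)"
    params = DBConnectorQueryHelper.prepare_sql_values(
        cls_name="PostgreSQL", sql_values=vals, sql_keys=keys
    )  # returns vals unchanged (MSSQL would instead get tuple(vals))

    # BigQuery path: named @ parameters bound via a QueryJobConfig of STRING scalars.
    bq_sql = DBConnectorQueryHelper.build_sql_insert_query(
        cls_name="BigQuery", table_name="dataset.test_output", sql_keys=keys
    )
    # bq_sql == "INSERT INTO dataset.test_output (id,data) VALUES (@id,@data)"
    bq_job_config = DBConnectorQueryHelper.prepare_sql_values(
        cls_name="BigQuery", sql_values=vals, sql_keys=keys
    )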

View File

@@ -214,14 +214,16 @@ class DestinationConnector(BaseConnector):
table_name=table_name,
database_entry=values,
)
sql_columns_and_values = DatabaseUtils.get_sql_columns_and_values_for_query(
engine=engine,
cls_name = db_class.__class__.__name__
sql_columns_and_values = DatabaseUtils.get_sql_query_data(
cls_name=cls_name,
connector_id=connector_instance.connector_id,
connector_settings=connector_settings,
table_name=table_name,
values=values,
)
DatabaseUtils.execute_write_query(
db_class=db_class,
engine=engine,
table_name=table_name,
sql_keys=list(sql_columns_and_values.keys()),

View File

@@ -69,3 +69,11 @@ class BigQueryTableNotFound(APIException):
"Please enter correct correct bigquery table in the form "
"{table}.{schema}.{database}."
)
class UnstractDBException(APIException):
default_detail = "Error creating/inserting to database. "
def __init__(self, detail: str = default_detail) -> None:
status_code = 500
super().__init__(detail=detail, code=status_code)

View File

@@ -0,0 +1,3 @@
from backend.celery import app as celery_app
__all__ = ["celery_app"]

View File

@@ -0,0 +1,160 @@
import datetime
import json
import os
from typing import Any
import pytest # type: ignore
from dotenv import load_dotenv
from unstract.connectors.databases.bigquery import BigQuery
from unstract.connectors.databases.mariadb import MariaDB
from unstract.connectors.databases.mssql import MSSQL
from unstract.connectors.databases.mysql import MySQL
from unstract.connectors.databases.postgresql import PostgreSQL
from unstract.connectors.databases.redshift import Redshift
from unstract.connectors.databases.snowflake import SnowflakeDB
from unstract.connectors.databases.unstract_db import UnstractDB
load_dotenv("test.env")
class BaseTestDB:
@pytest.fixture(autouse=True)
def base_setup(self) -> None:
self.postgres_creds = {
"user": os.getenv("DB_USER"),
"password": os.getenv("DB_PASSWORD"),
"host": os.getenv("DB_HOST"),
"port": os.getenv("DB_PORT"),
"database": os.getenv("DB_NAME"),
}
self.redshift_creds = {
"user": os.getenv("REDSHIFT_USER"),
"password": os.getenv("REDSHIFT_PASSWORD"),
"host": os.getenv("REDSHIFT_HOST"),
"port": os.getenv("REDSHIFT_PORT"),
"database": os.getenv("REDSHIFT_DB"),
}
self.snowflake_creds = {
"user": os.getenv("SNOWFLAKE_USER"),
"password": os.getenv("SNOWFLAKE_PASSWORD"),
"account": os.getenv("SNOWFLAKE_ACCOUNT"),
"role": os.getenv("SNOWFLAKE_ROLE"),
"database": os.getenv("SNOWFLAKE_DB"),
"schema": os.getenv("SNOWFLAKE_SCHEMA"),
"warehouse": os.getenv("SNOWFLAKE_WAREHOUSE"),
}
self.mssql_creds = {
"user": os.getenv("MSSQL_USER"),
"password": os.getenv("MSSQL_PASSWORD"),
"server": os.getenv("MSSQL_SERVER"),
"port": os.getenv("MSSQL_PORT"),
"database": os.getenv("MSSQL_DB"),
}
self.mysql_creds = {
"user": os.getenv("MYSQL_USER"),
"password": os.getenv("MYSQL_PASSWORD"),
"host": os.getenv("MYSQL_SERVER"),
"port": os.getenv("MYSQL_PORT"),
"database": os.getenv("MYSQL_DB"),
}
self.mariadb_creds = {
"user": os.getenv("MARIADB_USER"),
"password": os.getenv("MARIADB_PASSWORD"),
"host": os.getenv("MARIADB_SERVER"),
"port": os.getenv("MARIADB_PORT"),
"database": os.getenv("MARIADB_DB"),
}
self.database_entry = {
"created_by": "Unstract/DBWriter",
"created_at": datetime.datetime(2024, 5, 20, 7, 46, 57, 307998),
"data": '{"input_file": "simple.pdf", "result": "report"}',
}
valid_schema_name = "public"
invalid_schema_name = "public_1"
self.valid_postgres_creds = {**self.postgres_creds, "schema": valid_schema_name}
self.invalid_postgres_creds = {
**self.postgres_creds,
"schema": invalid_schema_name,
}
self.valid_redshift_creds = {**self.redshift_creds, "schema": valid_schema_name}
self.invalid_redshift_creds = {
**self.redshift_creds,
"schema": invalid_schema_name,
}
self.invalid_syntax_table_name = "invalid-syntax.name.test_output"
self.invalid_wrong_table_name = "database.schema.test_output"
self.valid_table_name = "test_output"
bigquery_json_str = os.getenv("BIGQUERY_CREDS", "{}")
self.bigquery_settings = json.loads(bigquery_json_str)
self.bigquery_settings["json_credentials"] = bigquery_json_str
self.valid_bigquery_table_name = "pandoras-tamer.bigquery_test.bigquery_output"
self.invalid_snowflake_db = {**self.snowflake_creds, "database": "invalid"}
self.invalid_snowflake_schema = {**self.snowflake_creds, "schema": "invalid"}
self.invalid_snowflake_warehouse = {
**self.snowflake_creds,
"warehouse": "invalid",
}
# Gets all valid db instances except
# Bigquery (table name needs to be written separately for bigquery)
@pytest.fixture(
params=[
("valid_postgres_creds", PostgreSQL),
("snowflake_creds", SnowflakeDB),
("mssql_creds", MSSQL),
("mysql_creds", MySQL),
("mariadb_creds", MariaDB),
("valid_redshift_creds", Redshift),
]
)
def valid_dbs_instance(self, request: Any) -> Any:
return self.get_db_instance(request=request)
# Gets all valid db instances except:
# Bigquery (table name needs to be written separately for bigquery)
# Redshift (can't process more than 64KB character type)
@pytest.fixture(
params=[
("valid_postgres_creds", PostgreSQL),
("snowflake_creds", SnowflakeDB),
("mssql_creds", MSSQL),
("mysql_creds", MySQL),
("mariadb_creds", MariaDB),
]
)
def valid_dbs_instance_to_handle_large_doc(self, request: Any) -> Any:
return self.get_db_instance(request=request)
def get_db_instance(self, request: Any) -> UnstractDB:
creds_name, db_class = request.param
creds = getattr(self, creds_name)
if not creds:
pytest.fail(f"Unknown credentials: {creds_name}")
db_instance = db_class(settings=creds)
return db_instance
# Gets all invalid-db instances for postgres, redshift:
@pytest.fixture(
params=[
("invalid_postgres_creds", PostgreSQL),
("invalid_redshift_creds", Redshift),
]
)
def invalid_dbs_instance(self, request: Any) -> Any:
return self.get_db_instance(request=request)
@pytest.fixture
def valid_bigquery_db_instance(self) -> Any:
return BigQuery(settings=self.bigquery_settings)
# Gets all invalid-db instances for snowflake:
@pytest.fixture(
params=[
("invalid_snowflake_db", SnowflakeDB),
("invalid_snowflake_schema", SnowflakeDB),
("invalid_snowflake_warehouse", SnowflakeDB),
]
)
def invalid_snowflake_db_instance(self, request: Any) -> Any:
return self.get_db_instance(request=request)

File diff suppressed because one or more lines are too long

View File

@@ -0,0 +1,97 @@
import pytest # type: ignore
from workflow_manager.endpoint.database_utils import DatabaseUtils
from workflow_manager.endpoint.exceptions import UnstractDBException
from unstract.connectors.databases.unstract_db import UnstractDB
from .base_test_db import BaseTestDB
class TestCreateTableIfNotExists(BaseTestDB):
def test_create_table_if_not_exists_valid(
self, valid_dbs_instance: UnstractDB
) -> None:
engine = valid_dbs_instance.get_engine()
result = DatabaseUtils.create_table_if_not_exists(
db_class=valid_dbs_instance,
engine=engine,
table_name=self.valid_table_name,
database_entry=self.database_entry,
)
assert result is None
def test_create_table_if_not_exists_bigquery_valid(
self, valid_bigquery_db_instance: UnstractDB
) -> None:
engine = valid_bigquery_db_instance.get_engine()
result = DatabaseUtils.create_table_if_not_exists(
db_class=valid_bigquery_db_instance,
engine=engine,
table_name=self.valid_bigquery_table_name,
database_entry=self.database_entry,
)
assert result is None
def test_create_table_if_not_exists_invalid_schema(
self, invalid_dbs_instance: UnstractDB
) -> None:
engine = invalid_dbs_instance.get_engine()
with pytest.raises(UnstractDBException):
DatabaseUtils.create_table_if_not_exists(
db_class=invalid_dbs_instance,
engine=engine,
table_name=self.valid_table_name,
database_entry=self.database_entry,
)
def test_create_table_if_not_exists_invalid_syntax(
self, valid_dbs_instance: UnstractDB
) -> None:
engine = valid_dbs_instance.get_engine()
with pytest.raises(UnstractDBException):
DatabaseUtils.create_table_if_not_exists(
db_class=valid_dbs_instance,
engine=engine,
table_name=self.invalid_syntax_table_name,
database_entry=self.database_entry,
)
def test_create_table_if_not_exists_wrong_table_name(
self, valid_dbs_instance: UnstractDB
) -> None:
engine = valid_dbs_instance.get_engine()
with pytest.raises(UnstractDBException):
DatabaseUtils.create_table_if_not_exists(
db_class=valid_dbs_instance,
engine=engine,
table_name=self.invalid_wrong_table_name,
database_entry=self.database_entry,
)
def test_create_table_if_not_exists_feature_not_supported(
self, invalid_dbs_instance: UnstractDB
) -> None:
engine = invalid_dbs_instance.get_engine()
with pytest.raises(UnstractDBException):
DatabaseUtils.create_table_if_not_exists(
db_class=invalid_dbs_instance,
engine=engine,
table_name=self.invalid_wrong_table_name,
database_entry=self.database_entry,
)
def test_create_table_if_not_exists_invalid_snowflake_db(
self, invalid_snowflake_db_instance: UnstractDB
) -> None:
engine = invalid_snowflake_db_instance.get_engine()
with pytest.raises(UnstractDBException):
DatabaseUtils.create_table_if_not_exists(
db_class=invalid_snowflake_db_instance,
engine=engine,
table_name=self.invalid_wrong_table_name,
database_entry=self.database_entry,
)
if __name__ == "__main__":
pytest.main()

View File

@@ -0,0 +1,169 @@
import os
import uuid
from typing import Any
import pytest # type: ignore
from workflow_manager.endpoint.database_utils import DatabaseUtils
from workflow_manager.endpoint.exceptions import UnstractDBException
from unstract.connectors.databases.redshift import Redshift
from unstract.connectors.databases.unstract_db import UnstractDB
from .base_test_db import BaseTestDB
class TestExecuteWriteQuery(BaseTestDB):
@pytest.fixture(autouse=True)
def setup(self, base_setup: Any) -> None:
self.sql_columns_and_values = {
"created_by": "Unstract/DBWriter",
"created_at": "2024-05-20 10:36:25.362609",
"data": '{"input_file": "simple.pdf", "result": "report"}',
"id": str(uuid.uuid4()),
}
def test_execute_write_query_valid(self, valid_dbs_instance: Any) -> None:
engine = valid_dbs_instance.get_engine()
result = DatabaseUtils.execute_write_query(
db_class=valid_dbs_instance,
engine=engine,
table_name=self.valid_table_name,
sql_keys=list(self.sql_columns_and_values.keys()),
sql_values=list(self.sql_columns_and_values.values()),
)
assert result is None
def test_execute_write_query_invalid_schema(
self, invalid_dbs_instance: Any
) -> None:
engine = invalid_dbs_instance.get_engine()
with pytest.raises(UnstractDBException):
DatabaseUtils.execute_write_query(
db_class=invalid_dbs_instance,
engine=engine,
table_name=self.valid_table_name,
sql_keys=list(self.sql_columns_and_values.keys()),
sql_values=list(self.sql_columns_and_values.values()),
)
def test_execute_write_query_invalid_syntax(self, valid_dbs_instance: Any) -> None:
engine = valid_dbs_instance.get_engine()
with pytest.raises(UnstractDBException):
DatabaseUtils.execute_write_query(
db_class=valid_dbs_instance,
engine=engine,
table_name=self.invalid_syntax_table_name,
sql_keys=list(self.sql_columns_and_values.keys()),
sql_values=list(self.sql_columns_and_values.values()),
)
def test_execute_write_query_feature_not_supported(
self, invalid_dbs_instance: Any
) -> None:
engine = invalid_dbs_instance.get_engine()
with pytest.raises(UnstractDBException):
DatabaseUtils.execute_write_query(
db_class=invalid_dbs_instance,
engine=engine,
table_name=self.invalid_wrong_table_name,
sql_keys=list(self.sql_columns_and_values.keys()),
sql_values=list(self.sql_columns_and_values.values()),
)
def load_text_to_sql_values(self) -> dict[str, Any]:
file_path = os.path.join(os.path.dirname(__file__), "static", "large_doc.txt")
with open(file_path, encoding="utf-8") as file:
content = file.read()
sql_columns_and_values = self.sql_columns_and_values.copy()
sql_columns_and_values["data"] = content
return sql_columns_and_values
@pytest.fixture
def valid_redshift_db_instance(self) -> Any:
return Redshift(self.valid_redshift_creds)
def test_execute_write_query_datatype_too_large_redshift(
self, valid_redshift_db_instance: Any
) -> None:
engine = valid_redshift_db_instance.get_engine()
sql_columns_and_values = self.load_text_to_sql_values()
with pytest.raises(UnstractDBException):
DatabaseUtils.execute_write_query(
db_class=valid_redshift_db_instance,
engine=engine,
table_name=self.valid_table_name,
sql_keys=list(sql_columns_and_values.keys()),
sql_values=list(sql_columns_and_values.values()),
)
def test_execute_write_query_bigquery_valid(
self, valid_bigquery_db_instance: Any
) -> None:
engine = valid_bigquery_db_instance.get_engine()
result = DatabaseUtils.execute_write_query(
db_class=valid_bigquery_db_instance,
engine=engine,
table_name=self.valid_bigquery_table_name,
sql_keys=list(self.sql_columns_and_values.keys()),
sql_values=list(self.sql_columns_and_values.values()),
)
assert result is None
def test_execute_write_query_wrong_table_name(
self, valid_dbs_instance: UnstractDB
) -> None:
engine = valid_dbs_instance.get_engine()
with pytest.raises(UnstractDBException):
DatabaseUtils.execute_write_query(
db_class=valid_dbs_instance,
engine=engine,
table_name=self.invalid_wrong_table_name,
sql_keys=list(self.sql_columns_and_values.keys()),
sql_values=list(self.sql_columns_and_values.values()),
)
def test_execute_write_query_bigquery_large_doc(
self, valid_bigquery_db_instance: Any
) -> None:
engine = valid_bigquery_db_instance.get_engine()
sql_columns_and_values = self.load_text_to_sql_values()
result = DatabaseUtils.execute_write_query(
db_class=valid_bigquery_db_instance,
engine=engine,
table_name=self.valid_bigquery_table_name,
sql_keys=list(sql_columns_and_values.keys()),
sql_values=list(sql_columns_and_values.values()),
)
assert result is None
def test_execute_write_query_invalid_snowflake_db(
self, invalid_snowflake_db_instance: UnstractDB
) -> None:
engine = invalid_snowflake_db_instance.get_engine()
with pytest.raises(UnstractDBException):
DatabaseUtils.execute_write_query(
db_class=invalid_snowflake_db_instance,
engine=engine,
table_name=self.invalid_wrong_table_name,
sql_keys=list(self.sql_columns_and_values.keys()),
sql_values=list(self.sql_columns_and_values.values()),
)
# Keep this test last so the large-doc case is covered for all DBs
def test_execute_write_query_large_doc(
self, valid_dbs_instance_to_handle_large_doc: Any
) -> None:
engine = valid_dbs_instance_to_handle_large_doc.get_engine()
sql_columns_and_values = self.load_text_to_sql_values()
result = DatabaseUtils.execute_write_query(
db_class=valid_dbs_instance_to_handle_large_doc,
engine=engine,
table_name=self.valid_table_name,
sql_keys=list(sql_columns_and_values.keys()),
sql_values=list(sql_columns_and_values.values()),
)
assert result is None
if __name__ == "__main__":
pytest.main()

View File

@@ -1,14 +1,22 @@
import datetime
import json
import logging
import os
from typing import Any
import google.api_core.exceptions
from google.cloud import bigquery
from google.cloud.bigquery import Client
from unstract.connectors.databases.exceptions import (
BigQueryForbiddenException,
BigQueryNotFoundException,
)
from unstract.connectors.databases.unstract_db import UnstractDB
from unstract.connectors.exceptions import ConnectorError
logger = logging.getLogger(__name__)
class BigQuery(UnstractDB):
def __init__(self, settings: dict[str, Any]):
@@ -87,3 +95,24 @@ class BigQuery(UnstractDB):
f"created_by string, created_at TIMESTAMP, "
)
return sql_query
def execute_query(
self, engine: Any, sql_query: str, sql_values: Any, **kwargs: Any
) -> None:
table_name = str(kwargs.get("table_name"))
try:
if sql_values:
engine.query(sql_query, job_config=sql_values)
else:
engine.query(sql_query)
except google.api_core.exceptions.Forbidden as e:
logger.error(f"Forbidden exception in creating/inserting data: {str(e)}")
raise BigQueryForbiddenException(
detail=e.message,
table_name=table_name,
) from e
except google.api_core.exceptions.NotFound as e:
logger.error(f"Resource not found in creating/inserting table: {str(e)}")
raise BigQueryNotFoundException(
detail=e.message, table_name=table_name
) from e

View File

@@ -0,0 +1,92 @@
from typing import Any
from unstract.connectors.exceptions import ConnectorBaseException
class UnstractDBConnectorException(ConnectorBaseException):
"""Base class for database-related exceptions from Unstract connectors."""
def __init__(
self,
detail: Any,
*args: Any,
**kwargs: Any,
) -> None:
default_detail = "Error creating/inserting to database. "
user_message = default_detail if not detail else detail
super().__init__(*args, user_message=user_message, **kwargs)
self.detail = user_message
class InvalidSyntaxException(UnstractDBConnectorException):
def __init__(self, detail: Any, database: Any) -> None:
default_detail = (
f"Error creating/writing to {database}. Syntax incorrect. "
f"Please check your table-name or schema. "
)
super().__init__(detail=default_detail + detail)
class InvalidSchemaException(UnstractDBConnectorException):
def __init__(self, detail: Any, database: str) -> None:
default_detail = f"Error creating/writing to {database}. Schema not valid. "
super().__init__(detail=default_detail + detail)
class UnderfinedTableException(UnstractDBConnectorException):
def __init__(self, detail: Any, database: str) -> None:
default_detail = (
f"Error creating/writing to {database}. Undefined table. "
f"Please check your table-name or schema. "
)
super().__init__(detail=default_detail + detail)
class ValueTooLongException(UnstractDBConnectorException):
def __init__(self, detail: Any, database: str) -> None:
default_detail = (
f"Error creating/writing to {database}. "
f"Size of the inserted data exceeds the limit provided by the database. "
)
super().__init__(detail=default_detail + detail)
class FeatureNotSupportedException(UnstractDBConnectorException):
def __init__(self, detail: Any, database: str) -> None:
default_detail = (
f"Error creating/writing to {database}. "
f"Feature not supported sql error. "
)
super().__init__(detail=default_detail + detail)
class SnowflakeProgrammingException(UnstractDBConnectorException):
def __init__(self, detail: Any, database: str) -> None:
default_detail = (
f"Error creating/writing to {database}. "
f"Please check your snowflake credentials. "
)
super().__init__(default_detail + detail)
class BigQueryForbiddenException(UnstractDBConnectorException):
def __init__(self, detail: Any, table_name: str) -> None:
default_detail = (
f"Error creating/writing to {table_name}. "
f"Access forbidden in bigquery. Please check your permissions. "
)
super().__init__(detail=default_detail + detail)
class BigQueryNotFoundException(UnstractDBConnectorException):
def __init__(self, detail: str, table_name: str) -> None:
default_detail = (
f"Error creating/writing to {table_name}. "
f"The requested resource was not found. "
)
super().__init__(detail=default_detail + detail)

View File

@@ -0,0 +1,20 @@
from typing import Any
class ExceptionHelper:
@staticmethod
def extract_byte_exception(e: Exception) -> Any:
"""_summary_
Extract error details from byte_exception.
Used by mssql
Args:
e (Exception): _description_
Returns:
Any: _description_
"""
error_message = str(e)
error_code, error_details = eval(error_message)
if isinstance(error_details, bytes):
error_details = error_details.decode("utf-8")
return error_details
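For context, pymssql and pymysql errors stringify as an (error_code, error_message_bytes) tuple, which is what the eval above unpacks. A small illustration with a made-up error:

    # Hypothetical driver error; real drivers raise their own exception classes.
    e = Exception(102, b"Incorrect syntax near the keyword SELECT.")
    # str(e) == "(102, b'Incorrect syntax near the keyword SELECT.')"
    details = ExceptionHelper.extract_byte_exception(e=e)
    # details == "Incorrect syntax near the keyword SELECT."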

View File

@@ -5,10 +5,11 @@ import pymysql
from pymysql.connections import Connection
from unstract.connectors.databases.mysql import MySQL
from unstract.connectors.databases.mysql_handler import MysqlHandler
from unstract.connectors.databases.unstract_db import UnstractDB
class MariaDB(UnstractDB):
class MariaDB(UnstractDB, MysqlHandler):
def __init__(self, settings: dict[str, Any]):
super().__init__("MariaDB")
@@ -74,3 +75,13 @@ class MariaDB(UnstractDB):
Mysql and Mariadb share same SQL column type
"""
return str(MySQL.sql_to_db_mapping(value=value))
def execute_query(
self, engine: Any, sql_query: str, sql_values: Any, **kwargs: Any
) -> None:
MysqlHandler.execute_query(
engine=engine,
sql_query=sql_query,
sql_values=sql_values,
database=self.database,
)

View File

@@ -1,11 +1,17 @@
import logging
import os
from typing import Any
import pymssql
from pymssql import Connection
import pymssql._pymssql as PyMssql
from pymssql import Connection # type: ignore
from unstract.connectors.databases.exceptions import InvalidSyntaxException
from unstract.connectors.databases.exceptions_helper import ExceptionHelper
from unstract.connectors.databases.unstract_db import UnstractDB
logger = logging.getLogger(__name__)
class MSSQL(UnstractDB):
def __init__(self, settings: dict[str, Any]):
@@ -49,16 +55,9 @@ class MSSQL(UnstractDB):
return True
def get_engine(self) -> Connection:
if self.port:
return pymssql.connect(
server=self.server,
port=self.port,
user=self.user,
password=self.password,
database=self.database,
)
return pymssql.connect(
return pymssql.connect( # type: ignore
server=self.server,
port=self.port,
user=self.user,
password=self.password,
database=self.database,
@@ -74,3 +73,22 @@ class MSSQL(UnstractDB):
f"created_by TEXT, created_at DATETIMEOFFSET, "
)
return sql_query
def execute_query(
self, engine: Any, sql_query: str, sql_values: Any, **kwargs: Any
) -> None:
try:
with engine.cursor() as cursor:
if sql_values:
cursor.execute(sql_query, sql_values)
else:
cursor.execute(sql_query)
engine.commit()
except (PyMssql.ProgrammingError, PyMssql.OperationalError) as e:
error_details = ExceptionHelper.extract_byte_exception(e=e)
logger.error(
f"Invalid syntax in creating/inserting mssql data: {error_details}"
)
raise InvalidSyntaxException(
detail=error_details, database=self.database
) from e

View File

@@ -5,10 +5,11 @@ from typing import Any
import pymysql
from pymysql.connections import Connection
from unstract.connectors.databases.mysql_handler import MysqlHandler
from unstract.connectors.databases.unstract_db import UnstractDB
class MySQL(UnstractDB):
class MySQL(UnstractDB, MysqlHandler):
def __init__(self, settings: dict[str, Any]):
super().__init__("MySQL")
@@ -78,3 +79,13 @@ class MySQL(UnstractDB):
datetime.datetime: "TIMESTAMP",
}
return mapping.get(python_type, "LONGTEXT")
def execute_query(
self, engine: Any, sql_query: str, sql_values: Any, **kwargs: Any
) -> None:
MysqlHandler.execute_query(
engine=engine,
sql_query=sql_query,
sql_values=sql_values,
database=self.database,
)

View File

@@ -0,0 +1,29 @@
import logging
from typing import Any
import pymysql.err as MysqlError
from unstract.connectors.databases.exceptions import InvalidSyntaxException
from unstract.connectors.databases.exceptions_helper import ExceptionHelper
logger = logging.getLogger(__name__)
class MysqlHandler:
@staticmethod
def execute_query(
engine: Any, sql_query: str, sql_values: Any, database: Any
) -> None:
try:
with engine.cursor() as cursor:
if sql_values:
cursor.execute(sql_query, sql_values)
else:
cursor.execute(sql_query)
engine.commit()
except MysqlError.ProgrammingError as e:
error_details = ExceptionHelper.extract_byte_exception(e=e)
logger.error(
f"Invalid syntax in creating/inserting mysql data: {error_details}"
)
raise InvalidSyntaxException(detail=error_details, database=database) from e

View File

@@ -4,10 +4,11 @@ from typing import Any
import psycopg2
from psycopg2.extensions import connection
from unstract.connectors.databases.psycopg_handler import PsycoPgHandler
from unstract.connectors.databases.unstract_db import UnstractDB
class PostgreSQL(UnstractDB):
class PostgreSQL(UnstractDB, PsycoPgHandler):
def __init__(self, settings: dict[str, Any]):
super().__init__("PostgreSQL")
@@ -71,3 +72,13 @@ class PostgreSQL(UnstractDB):
options=f"-c search_path={self.schema}",
)
return con
def execute_query(
self, engine: Any, sql_query: str, sql_values: Any, **kwargs: Any
) -> None:
PsycoPgHandler.execute_query(
engine=engine,
sql_query=sql_query,
sql_values=sql_values,
database=self.database,
)

View File

@@ -0,0 +1,50 @@
import logging
from typing import Any
from psycopg2 import errors as PsycopgError
from unstract.connectors.databases.exceptions import (
FeatureNotSupportedException,
InvalidSchemaException,
InvalidSyntaxException,
UnderfinedTableException,
ValueTooLongException,
)
logger = logging.getLogger(__name__)
class PsycoPgHandler:
@staticmethod
def execute_query(
engine: Any, sql_query: str, sql_values: Any, database: Any
) -> None:
try:
with engine.cursor() as cursor:
if sql_values:
cursor.execute(sql_query, sql_values)
else:
cursor.execute(sql_query)
engine.commit()
except PsycopgError.InvalidSchemaName as e:
logger.error(f"Invalid schema in creating table: {e.pgerror}")
raise InvalidSchemaException(detail=e.pgerror, database=database) from e
except PsycopgError.UndefinedTable as e:
logger.error(f"Undefined table in inserting: {e.pgerror}")
raise UnderfinedTableException(detail=e.pgerror, database=database) from e
except PsycopgError.SyntaxError as e:
logger.error(f"Invalid syntax in creating/inserting data: {e.pgerror}")
raise InvalidSyntaxException(detail=e.pgerror, database=database) from e
except PsycopgError.FeatureNotSupported as e:
logger.error(
f"feature not supported in creating/inserting data: {e.pgerror}"
)
raise FeatureNotSupportedException(
detail=e.pgerror, database=database
) from e
except (
PsycopgError.StringDataRightTruncation,
PsycopgError.InternalError_,
) as e:
logger.error(f"value too long for datatype: {e.pgerror}")
raise ValueTooLongException(detail=e.pgerror, database=database) from e

View File

@@ -5,10 +5,11 @@ from typing import Any
import psycopg2
from psycopg2.extensions import connection
from unstract.connectors.databases.psycopg_handler import PsycoPgHandler
from unstract.connectors.databases.unstract_db import UnstractDB
class Redshift(UnstractDB):
class Redshift(UnstractDB, PsycoPgHandler):
def __init__(self, settings: dict[str, Any]):
super().__init__("Redshift")
@@ -74,14 +75,13 @@ class Redshift(UnstractDB):
str: _description_
"""
python_type = type(value)
mapping = {
str: "VARCHAR(65535)",
str: "SUPER",
int: "BIGINT",
float: "DOUBLE PRECISION",
datetime.datetime: "TIMESTAMP",
}
return mapping.get(python_type, "VARCHAR(65535)")
return mapping.get(python_type, "SUPER")
@staticmethod
def get_create_table_query(table: str) -> str:
@@ -91,3 +91,13 @@ class Redshift(UnstractDB):
f"created_by VARCHAR(65535), created_at TIMESTAMP, "
)
return sql_query
def execute_query(
self, engine: Any, sql_query: str, sql_values: Any, **kwargs: Any
) -> None:
PsycoPgHandler.execute_query(
engine=engine,
sql_query=sql_query,
sql_values=sql_values,
database=self.database,
)

View File

@@ -1,11 +1,16 @@
import logging
import os
from typing import Any
import snowflake.connector
import snowflake.connector.errors as SnowflakeError
from snowflake.connector.connection import SnowflakeConnection
from unstract.connectors.databases.exceptions import SnowflakeProgrammingException
from unstract.connectors.databases.unstract_db import UnstractDB
logger = logging.getLogger(__name__)
class SnowflakeDB(UnstractDB):
def __init__(self, settings: dict[str, Any]):
@@ -70,3 +75,22 @@ class SnowflakeDB(UnstractDB):
f"created_by TEXT, created_at TIMESTAMP, "
)
return sql_query
def execute_query(
self, engine: Any, sql_query: str, sql_values: Any, **kwargs: Any
) -> None:
try:
with engine.cursor() as cursor:
if sql_values:
cursor.execute(sql_query, sql_values)
else:
cursor.execute(sql_query)
engine.commit()
except SnowflakeError.ProgrammingError as e:
logger.error(
f"snowflake programming error in crearing/inserting table: "
f"{e.msg} {e.errno}"
)
raise SnowflakeProgrammingException(
detail=e.msg, database=self.database
) from e

View File

@@ -68,7 +68,7 @@ class UnstractDB(UnstractConnector, ABC):
try:
self.get_engine()
except Exception as e:
raise ConnectorError(str(e))
raise ConnectorError(str(e)) from e
return True
def execute(self, query: str) -> Any:
@@ -77,7 +77,7 @@ class UnstractDB(UnstractConnector, ABC):
cursor.execute(query)
return cursor.fetchall()
except Exception as e:
raise ConnectorError(str(e))
raise ConnectorError(str(e)) from e
@staticmethod
def sql_to_db_mapping(value: str) -> str:
@@ -107,3 +107,9 @@ class UnstractDB(UnstractConnector, ABC):
f"created_by TEXT, created_at TIMESTAMP, "
)
return sql_query
@abstractmethod
def execute_query(
self, engine: Any, sql_query: str, sql_values: Any, **kwargs: Any
) -> None:
pass
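With execute_query now abstract, every connector has to provide its own implementation. A minimal sketch of the expected shape; the class name and the blanket exception mapping are illustrative, since real connectors map driver-specific errors to UnstractDBConnectorException subclasses as shown above:

    from typing import Any

    from unstract.connectors.databases.exceptions import UnstractDBConnectorException
    from unstract.connectors.databases.unstract_db import UnstractDB


    class ExampleDB(UnstractDB):
        def execute_query(
            self, engine: Any, sql_query: str, sql_values: Any, **kwargs: Any
        ) -> None:
            try:
                with engine.cursor() as cursor:
                    if sql_values:
                        cursor.execute(sql_query, sql_values)  # parameterized INSERT
                    else:
                        cursor.execute(sql_query)  # e.g. CREATE TABLE IF NOT EXISTS
                engine.commit()
            except Exception as e:
                # Real connectors raise a specific subclass with a user-friendly detail.
                raise UnstractDBConnectorException(detail=str(e)) from e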