"""``SQLDataSet`` to load and save data to a SQL backend."""
import copy
import re
from pathlib import PurePosixPath
from typing import Any, Dict, Optional
import fsspec
import pandas as pd
from sqlalchemy import create_engine
from sqlalchemy.exc import NoSuchModuleError
from kedro.io.core import (
AbstractDataSet,
DataSetError,
get_filepath_str,
get_protocol_and_path,
)
__all__ = ["SQLTableDataSet", "SQLQueryDataSet"]
KNOWN_PIP_INSTALL = {
"psycopg2": "psycopg2",
"mysqldb": "mysqlclient",
"cx_Oracle": "cx_Oracle",
}
DRIVER_ERROR_MESSAGE = """
A module/driver is missing when connecting to your SQL server. SQLDataSet
supports SQLAlchemy drivers. Please refer to
https://docs.sqlalchemy.org/en/13/core/engines.html#supported-databases
for more information.
\n\n
"""
def _find_known_drivers(module_import_error: ImportError) -> Optional[str]:
"""Looks up known keywords in a ``ModuleNotFoundError`` so that it can
provide better guideline for the user.
Args:
module_import_error: Error raised while connecting to a SQL server.
Returns:
Instructions for installing missing driver. An empty string is
returned in case error is related to an unknown driver.
"""
# module errors contain string "No module name 'module_name'"
# we are trying to extract module_name surrounded by quotes here
res = re.findall(r"'(.*?)'", str(module_import_error.args[0]).lower())
# in case module import error does not match our expected pattern
# we have no recommendation
if not res:
return None
missing_module = res[0]
if KNOWN_PIP_INSTALL.get(missing_module):
return (
f"You can also try installing missing driver with\n"
f"\npip install {KNOWN_PIP_INSTALL.get(missing_module)}"
)
return None
def _get_missing_module_error(import_error: ImportError) -> DataSetError:
    """Wrap a driver ``ImportError`` in a ``DataSetError`` with guidance.

    If the missing driver is a known one, the message includes a pip install
    hint; otherwise the raw import error is appended.
    """
    instruction = _find_known_drivers(import_error)
    if instruction is not None:
        return DataSetError(f"{DRIVER_ERROR_MESSAGE}{instruction}")
    return DataSetError(
        f"{DRIVER_ERROR_MESSAGE}Loading failed with error:\n\n{str(import_error)}"
    )
def _get_sql_alchemy_missing_error() -> DataSetError:
    """Build the ``DataSetError`` raised when SQLAlchemy does not recognise
    the SQL dialect in the connection string.
    """
    message = (
        "The SQL dialect in your connection is not supported by "
        "SQLAlchemy. Please refer to "
        "https://docs.sqlalchemy.org/en/13/core/engines.html#supported-databases "
        "for more information."
    )
    return DataSetError(message)
class SQLTableDataSet(AbstractDataSet):
    """``SQLTableDataSet`` loads data from a SQL table and saves a pandas
    dataframe to a table.

    It uses ``pandas.DataFrame`` internally, so it supports all allowed
    pandas options on ``read_sql_table`` and ``to_sql``. Since pandas uses
    SQLAlchemy behind the scenes, a compatible connection string must be
    provided either in ``credentials`` (see the example below) or in
    ``load_args`` and ``save_args``. Supported connection string formats:
    https://docs.sqlalchemy.org/en/13/core/engines.html#database-urls

    ``SQLTableDataSet`` modifies the save parameters and stores the data
    with no index, so that load and save are symmetric.

    Example adding a catalog entry with the YAML API:

    .. code-block:: yaml

        >>> shuttles_table_dataset:
        >>>   type: pandas.SQLTableDataSet
        >>>   credentials: db_credentials
        >>>   table_name: shuttles
        >>>   load_args:
        >>>     schema: dwschema
        >>>   save_args:
        >>>     schema: dwschema
        >>>     if_exists: replace

    Sample database credentials entry in ``credentials.yml``:

    .. code-block:: yaml

        >>> db_creds:
        >>>   con: postgresql://scott:tiger@localhost/test

    Example using the Python API:
    ::

        >>> from kedro.extras.datasets.pandas import SQLTableDataSet
        >>> import pandas as pd
        >>>
        >>> data = pd.DataFrame({"col1": [1, 2], "col2": [4, 5],
        >>>                      "col3": [5, 6]})
        >>> data_set = SQLTableDataSet(
        >>>     table_name="table_a",
        >>>     credentials={"con": "postgresql://scott:tiger@localhost/test"},
        >>> )
        >>> data_set.save(data)
        >>> reloaded = data_set.load()
        >>> assert data.equals(reloaded)
    """

    DEFAULT_LOAD_ARGS: Dict[str, Any] = {}
    DEFAULT_SAVE_ARGS: Dict[str, Any] = {"index": False}

    # ``Any`` is used instead of ``sqlalchemy.engine.Engine`` because of
    # Sphinx. One engine is shared per connection string across instances.
    engines: Dict[str, Any] = {}

    def __init__(
        self,
        table_name: str,
        credentials: Dict[str, Any],
        load_args: Dict[str, Any] = None,
        save_args: Dict[str, Any] = None,
    ) -> None:
        """Creates a new ``SQLTableDataSet``.

        Args:
            table_name: The table name to load or save data to. It overwrites
                ``name`` in ``save_args`` and ``table_name`` in ``load_args``.
            credentials: A dictionary with a ``SQLAlchemy`` connection string
                under the key 'con'. It overwrites the `con` parameter in
                ``load_args`` and ``save_args`` in case it is provided there.
                Supported connection string formats:
                https://docs.sqlalchemy.org/en/13/core/engines.html#database-urls
            load_args: Provided to the underlying pandas ``read_sql_table``
                function along with the connection string. All supported
                arguments:
                https://pandas.pydata.org/pandas-docs/stable/generated/pandas.read_sql_table.html
            save_args: Provided to the underlying pandas ``to_sql`` function
                along with the connection string. All supported arguments:
                https://pandas.pydata.org/pandas-docs/stable/generated/pandas.DataFrame.to_sql.html
                It has ``index=False`` in the default parameters.

        Raises:
            DataSetError: When either ``table_name`` or ``con`` is empty.
        """
        if not table_name:
            raise DataSetError("`table_name` argument cannot be empty.")
        if not (credentials and credentials.get("con")):
            raise DataSetError(
                "`con` argument cannot be empty. Please "
                "provide a SQLAlchemy connection string."
            )

        # Overlay user-supplied arguments on top of the class-level defaults.
        self._load_args = copy.deepcopy(self.DEFAULT_LOAD_ARGS)
        self._load_args.update(load_args or {})
        self._save_args = copy.deepcopy(self.DEFAULT_SAVE_ARGS)
        self._save_args.update(save_args or {})

        # ``table_name`` always wins over whatever the args dictionaries say.
        self._load_args["table_name"] = table_name
        self._save_args["name"] = table_name

        self._connection_str = credentials["con"]
        self.create_connection(self._connection_str)

    @classmethod
    def create_connection(cls, connection_str: str) -> None:
        """Given a connection string, create singleton connection
        to be used across all instances of `SQLTableDataSet` that
        need to connect to the same source.
        """
        if connection_str in cls.engines:
            return

        try:
            engine = create_engine(connection_str)
        except ImportError as exc:
            # A DBAPI driver module (e.g. psycopg2) is not installed.
            raise _get_missing_module_error(exc) from exc
        except NoSuchModuleError as exc:
            # The dialect itself is unknown to SQLAlchemy.
            raise _get_sql_alchemy_missing_error() from exc

        cls.engines[connection_str] = engine

    def _describe(self) -> Dict[str, Any]:
        # Report the args without the keys that are forced by ``table_name``.
        load_args = copy.deepcopy(self._load_args)
        save_args = copy.deepcopy(self._save_args)
        table_name = load_args.pop("table_name")
        save_args.pop("name")
        return {
            "table_name": table_name,
            "load_args": load_args,
            "save_args": save_args,
        }

    def _load(self) -> pd.DataFrame:
        engine = self.engines[self._connection_str]  # type: ignore
        return pd.read_sql_table(con=engine, **self._load_args)

    def _save(self, data: pd.DataFrame) -> None:
        engine = self.engines[self._connection_str]  # type: ignore
        data.to_sql(con=engine, **self._save_args)

    def _exists(self) -> bool:
        engine = self.engines[self._connection_str]  # type: ignore
        schema = self._load_args.get("schema", None)
        return self._load_args["table_name"] in engine.table_names(schema)
class SQLQueryDataSet(AbstractDataSet):
    """``SQLQueryDataSet`` loads data from a provided SQL query.

    It uses ``pandas.DataFrame`` internally, so it supports all allowed
    pandas options on ``read_sql_query``. Since pandas uses SQLAlchemy
    behind the scenes, a compatible connection string must be provided
    either in ``credentials`` (see the example below) or in ``load_args``.
    Supported connection string formats:
    https://docs.sqlalchemy.org/en/13/core/engines.html#database-urls

    This is a read-only data set: it does not support the save method.
    To save data to a SQL server use ``SQLTableDataSet``.

    Example adding a catalog entry with the YAML API:

    .. code-block:: yaml

        >>> shuttle_id_dataset:
        >>>   type: pandas.SQLQueryDataSet
        >>>   sql: "select shuttle, shuttle_id from spaceflights.shuttles;"
        >>>   credentials: db_credentials
        >>>   layer: raw

    Sample database credentials entry in ``credentials.yml``:

    .. code-block:: yaml

        >>> db_creds:
        >>>   con: postgresql://scott:tiger@localhost/test

    Example using the Python API:
    ::

        >>> from kedro.extras.datasets.pandas import SQLQueryDataSet
        >>>
        >>> data_set = SQLQueryDataSet(
        >>>     sql="SELECT * FROM table_a",
        >>>     credentials={"con": "postgresql://scott:tiger@localhost/test"},
        >>> )
        >>> sql_data = data_set.load()
    """

    # ``Any`` is used instead of ``sqlalchemy.engine.Engine`` because of
    # Sphinx. One engine is shared per connection string across instances.
    engines: Dict[str, Any] = {}

    def __init__(  # pylint: disable=too-many-arguments
        self,
        sql: str = None,
        credentials: Dict[str, Any] = None,
        load_args: Dict[str, Any] = None,
        fs_args: Dict[str, Any] = None,
        filepath: str = None,
    ) -> None:
        """Creates a new ``SQLQueryDataSet``.

        Args:
            sql: The sql query statement. Mutually exclusive with
                ``filepath``.
            credentials: A dictionary with a ``SQLAlchemy`` connection string
                under the key 'con'. It overwrites the `con` parameter in
                ``load_args`` in case it is provided there. Supported
                connection string formats:
                https://docs.sqlalchemy.org/en/13/core/engines.html#database-urls
            load_args: Provided to the underlying pandas ``read_sql_query``
                function along with the connection string. All supported
                arguments:
                https://pandas.pydata.org/pandas-docs/stable/generated/pandas.read_sql_query.html
            fs_args: Extra arguments to pass into the underlying filesystem
                class constructor (e.g. `{"project": "my-project"}` for
                ``GCSFileSystem``), as well as nested keys `open_args_load`
                and `open_args_save` forwarded to the filesystem's `open`
                method. All defaults are preserved, except `mode`, which is
                set to `r` when loading.
            filepath: A path to a file with a sql query statement. Mutually
                exclusive with ``sql``.

        Raises:
            DataSetError: When either ``sql`` or ``con`` parameters is empty.
        """
        if sql and filepath:
            raise DataSetError(
                "`sql` and `filepath` arguments cannot both be provided."
                "Please only provide one."
            )
        if not (sql or filepath):
            raise DataSetError(
                "`sql` and `filepath` arguments cannot both be empty."
                "Please provide a sql query or path to a sql query file."
            )
        if not (credentials and credentials.get("con")):
            raise DataSetError(
                "`con` argument cannot be empty. Please "
                "provide a SQLAlchemy connection string."
            )

        self._load_args: Dict[str, Any] = {}
        if load_args is not None:
            self._load_args.update(load_args)

        if sql:
            # Inline query: no filesystem involved.
            self._load_args["sql"] = sql
            self._filepath = None
        else:
            # The query lives in a file: set up the filesystem to read it.
            fs_kwargs = copy.deepcopy(fs_args) or {}
            fs_credentials = fs_kwargs.pop("credentials", {})
            protocol, path = get_protocol_and_path(str(filepath))

            self._protocol = protocol
            self._fs = fsspec.filesystem(
                self._protocol, **fs_credentials, **fs_kwargs
            )
            self._filepath = path

        self._connection_str = credentials["con"]
        self.create_connection(self._connection_str)

    @classmethod
    def create_connection(cls, connection_str: str) -> None:
        """Given a connection string, create singleton connection
        to be used across all instances of `SQLQueryDataSet` that
        need to connect to the same source.
        """
        if connection_str in cls.engines:
            return

        try:
            engine = create_engine(connection_str)
        except ImportError as exc:
            # A DBAPI driver module (e.g. psycopg2) is not installed.
            raise _get_missing_module_error(exc) from exc
        except NoSuchModuleError as exc:
            # The dialect itself is unknown to SQLAlchemy.
            raise _get_sql_alchemy_missing_error() from exc

        cls.engines[connection_str] = engine

    def _describe(self) -> Dict[str, Any]:
        load_args = copy.deepcopy(self._load_args)
        return {
            "sql": str(load_args.pop("sql", None)),
            "filepath": str(self._filepath),
            "load_args": str(load_args),
        }

    def _load(self) -> pd.DataFrame:
        load_args = copy.deepcopy(self._load_args)
        engine = self.engines[self._connection_str]  # type: ignore

        if self._filepath:
            # Read the query text from the file at load time, so edits to
            # the file are picked up without recreating the data set.
            load_path = get_filepath_str(
                PurePosixPath(self._filepath), self._protocol
            )
            with self._fs.open(load_path, mode="r") as fs_file:
                load_args["sql"] = fs_file.read()

        return pd.read_sql_query(con=engine, **load_args)

    def _save(self, data: pd.DataFrame) -> None:
        raise DataSetError("`save` is not supported on SQLQueryDataSet")