Source code for pgcom.base

# Public names exported by this module.
__all__ = [
    "BaseConnector",
    "BaseCommuter",
]

from contextlib import contextmanager
from typing import Any, Iterator, List, Mapping, Optional, Sequence, Tuple, Union

import abc

import numpy as np
import psycopg2
from psycopg2 import sql
from psycopg2.extensions import register_adapter, AsIs
from psycopg2.extras import execute_batch

from . import exc

# Query parameters are either positional (a sequence) or named (a mapping
# from placeholder name to value).
QueryParams = Union[Sequence[Any], Mapping[str, Any]]
# Module-level side effect: register ``AsIs`` as the psycopg2 adapter for
# NumPy ``int64`` / ``float64`` scalars so they can be used as query
# parameters.
register_adapter(np.int64, AsIs)
register_adapter(np.float64, AsIs)


class BaseConnector(abc.ABC):
    """Base class for all connectors.

    Keyword arguments are stored on the instance in ``self._kwargs`` for
    subclasses to use. The ``db_name`` keyword is accepted as an alias and
    is stored under the key ``dbname``.

    Args:
        **kwargs: Connection parameters (for example ``host``, ``user``,
            ``dbname``).
    """

    def __init__(self, **kwargs: Any) -> None:
        self._kwargs = kwargs
        # Normalize the ``db_name`` alias to the ``dbname`` key.
        if "db_name" in self._kwargs:
            self._kwargs["dbname"] = self._kwargs.pop("db_name")

    def __del__(self) -> None:
        # Release connections when the connector is garbage collected.
        self.close_all()

    def __repr__(self) -> str:
        """Return a short description built from host, user and dbname.

        Only the keys that were supplied are shown, in that fixed order,
        each rendered as ``key=value`` followed by a single space.
        """
        # Formatting through an f-string (rather than ``str + value``) keeps
        # the output identical for string values while not failing on
        # non-string ones.
        shown = [
            f"{key}={self._kwargs[key]} "
            for key in ("host", "user", "dbname")
            if key in self._kwargs
        ]
        return "(" + "".join(shown) + ")"

    @abc.abstractmethod
    @contextmanager
    def open_connection(self) -> Iterator[psycopg2.extensions.connection]:
        """Generates a new connection."""
        raise NotImplementedError

    @abc.abstractmethod
    def close_all(self) -> None:
        """Close all active connections."""
        raise NotImplementedError
class BaseCommuter:
    """Base class for all commuters.

    Args:
        connector: Instance of connection handler, any subclass
            inherited from :class:`~pgcom.base.BaseConnector`.
    """

    def __init__(self, connector: BaseConnector) -> None:
        self.connector = connector

    def __repr__(self) -> str:
        return repr(self.connector)

    def execute(
        self, cmd: Union[str, sql.Composed], values: Optional[QueryParams] = None
    ) -> None:
        """Execute a database operation (query or command).

        The statement is run through :meth:`_execute` with its default
        settings; any rows it produces are discarded.

        Args:
            cmd: SQL query to be executed.
            values: Query parameters.

        Raises:
            QueryExecutionError: if execution fails.
        """
        self._execute(cmd=cmd, values=values)

    def _execute(
        self,
        cmd: Union[str, sql.Composed],
        values: Optional[QueryParams] = None,
        commit: Optional[bool] = True,
        batch: Optional[bool] = False,
    ) -> Tuple[List[Any], List[str]]:
        """Execute a database operation, query or command.

        Args:
            cmd: SQL command.
            values: Query parameters.
            commit: Commit the results if True.
            batch: Use execute_batch method if True.

        Returns:
            List of rows of a query result and list of column names.
            Two empty lists are returned if there is no records to fetch.

        Raises:
            QueryExecutionError: if execution fails.
        """
        fetched: List[Any] = []
        columns: List[str] = []

        with self.connector.open_connection() as conn:
            try:
                with conn.cursor() as cur:
                    if batch:
                        execute_batch(cur, cmd, values)
                    elif values is None:
                        cur.execute(cmd)
                    else:
                        cur.execute(cmd, values)

                    # ``description`` is None when the statement produced
                    # no result set, in which case there is nothing to fetch.
                    if cur.description is not None:
                        fetched = cur.fetchall()
                        columns = [desc[0] for desc in cur.description]

                if commit:
                    conn.commit()
            except Exception as e:
                # Undo the failed transaction first; if even the rollback
                # fails, report that failure instead.
                try:
                    conn.rollback()
                except Exception as ex:
                    exc.raise_with_traceback(
                        exc.QueryExecutionError(
                            f"Execution failed on sql: {cmd}\n{ex}\n "
                            f"unable to rollback"
                        )
                    )

                exc.raise_with_traceback(
                    exc.QueryExecutionError(f"Execution failed on sql: {cmd}\n{e}\n")
                )

        return fetched, columns

    def _get_schema(self, table_name: str) -> Tuple[str, str]:
        """Return schema and table name.

        Args:
            table_name: Table name, optionally qualified as ``schema.table``.

        Returns:
            ``(schema, table)`` when the name has exactly one dot;
            otherwise ``("public", table_name)`` with the name unchanged.
        """
        names = table_name.split(".")
        if len(names) == 2:
            return names[0], names[1]
        return "public", table_name