Source code for pgcom.commuter

__all__ = ["Commuter"]

from io import StringIO
from typing import Any, Dict, List, Mapping, Optional, Sequence, Union

import numpy as np
import pandas as pd
from psycopg2 import sql

from . import exc, queries
from .base import BaseCommuter
from .connector import Connector

QueryParams = Union[Sequence[Any], Mapping[str, Any]]


[docs]class Commuter(BaseCommuter): """Communication agent. When creating a new instance of Commuter, the connection pool is created and the connection is established. The typical usage of Commuter is therefore once per particular database, held globally for the lifetime of a single application process. Args: pool_size: The maximum amount of connections the pool will support. pre_ping: If True, the pool will emit a "ping" on the connection to test if the connection is alive. If not, the connection will be reconnected. max_reconnects: The maximum amount of reconnects, defaults to 3. """ connector: Connector def __init__( self, pool_size: int = 20, pre_ping: bool = False, max_reconnects: int = 3, **kwargs: str, ) -> None: super().__init__(Connector(pool_size, pre_ping, max_reconnects, **kwargs)) def __repr__(self) -> str: return repr(self.connector)
[docs] def select( self, cmd: Union[str, sql.Composed], values: Optional[QueryParams] = None ) -> pd.DataFrame: """Read SQL query into a DataFrame. Returns a DataFrame corresponding to the result of the query. Args: cmd: string SQL query to be executed. values: Parameters to pass to execute method. Returns: Pandas.DataFrame. """ records, columns = self._execute(cmd, values=values) df = pd.DataFrame.from_records(records, columns=columns) return df
[docs] def select_one( self, cmd: Union[str, sql.Composed], values: Optional[QueryParams] = None, default: Optional[Any] = None, ) -> Any: """Select the first element of returned DataFrame. Args: cmd: string SQL query to be executed. values: Parameters to pass to execute method. default: If query result is empty, then return the default value. """ fetched, _ = self._execute(cmd, values=values) try: value = fetched[0][0] if value is None: value = default except IndexError: value = default return value
[docs] def execute_script(self, path2script: str) -> None: """Execute query from file. Args: path2script: Path to the file with the query. """ with open(path2script, "r") as fh: cmd = fh.read() self._execute(cmd=cmd)
[docs] def insert( self, table_name: str, data: pd.DataFrame, columns: Optional[List[str]] = None, placeholders: Optional[List[str]] = None, ) -> None: """Write rows from a DataFrame to a database table. Args: table_name: Name of the destination table. data: Pandas.DataFrame with the data to be inserted. columns: List of column names used for insert. If not specified then all the columns are used. Defaults to None. placeholders: List of placeholders. If not specified then the default placeholders are used. Defaults to None. Examples: .. code:: >>> self.insert("people", data) Insert two columns, name and age. .. code:: >>> self.insert("people", data, columns=["name", "age"]) You can customize placeholders to implement advanced insert, e.g. to insert geometry data in a database with PostGIS extension. .. code:: >>> self.insert( ... table_name="polygons", ... data=data, ... columns=["name", "geom"], ... placeholders=["%s", "ST_GeomFromText(%s, 4326)"]) """ if columns is None: columns = list(data.columns) if placeholders is None: placeholders = sql.Placeholder() * len(columns) else: placeholders = sql.Composed([sql.SQL(p) for p in placeholders]) cmd = sql.SQL("INSERT INTO {} ({}) VALUES ({})").format( sql.SQL(table_name), sql.SQL(", ").join(map(sql.Identifier, columns)), sql.SQL(", ").join(placeholders), ) rows = data[columns].to_numpy(na_value=None) for values in [tuple(row) for row in rows]: self._execute(cmd=cmd, values=values)
[docs] def insert_row( self, table_name: str, return_id: Optional[str] = None, **kwargs: Any ) -> Optional[int]: """Implements insert command. Inserted values are passed through the keyword arguments. Args: table_name: Name of the destination table. return_id: Name of the returned serial key. """ sid = None keys = list(kwargs.keys()) cmd = sql.SQL("INSERT INTO {} ({}) VALUES ({})").format( sql.SQL(table_name), sql.SQL(", ").join(map(sql.Identifier, keys)), sql.SQL(", ").join(map(sql.Placeholder, keys)), ) if return_id is not None: sid = self.insert_return(cmd, return_id=return_id, values=kwargs) else: self._execute(cmd=cmd, values=kwargs) return sid
[docs] def insert_return( self, cmd: Union[str, sql.Composed], values: Optional[QueryParams] = None, return_id: Optional[str] = None, ) -> int: """Insert a new row to the table and return the serial key of the newly inserted row. Args: cmd: INSERT INTO command. values: Collection of values to be inserted. return_id: Name of the returned serial key. """ if return_id is not None: cmd = sql.Composed( [ sql.SQL(cmd) if isinstance(cmd, str) else cmd, sql.SQL(" RETURNING {}").format(sql.Identifier(return_id)), ] ) fetched, _ = self._execute(cmd, values) try: sid = fetched[0][0] except IndexError: sid = 0 return sid
[docs] def copy_from( self, table_name: str, data: pd.DataFrame, format_data: bool = False, sep: str = ",", na_value: str = "", where: Optional[Union[str, sql.Composed]] = None, ) -> None: """Places DataFrame to a buffer and apply COPY FROM command. Args: table_name: Name of the table where to insert. data DataFrame from where to insert. format_data: Reorder columns and adjust dtypes wrt to table metadata from information_schema. sep: String of length 1. Field delimiter for the output file. Defaults to ",". na_value: Missing data representation, defaults to "". where: WHERE clause used to specify a condition while deleting data from the table before applying copy_from, DELETE command is not executed if not specified. Raises: CopyError: if execution fails. """ df = data.copy() if format_data: df = self._format_data(df, table_name, sep=sep) with self.connector.open_connection() as conn: try: with conn.cursor() as cur: if where is not None: if isinstance(where, str): where = sql.SQL(where) cmd = sql.Composed( [ sql.SQL("DELETE FROM {} WHERE ").format( sql.SQL(table_name) ), where, ] ) cur.execute(cmd) # DataFrame to buffer s_buf = StringIO() df.to_csv( path_or_buf=s_buf, sep=sep, na_rep=na_value, index=False, header=False, ) s_buf.seek(0) # copy from buffer columns = ", ".join(df.columns) cmd = ( f"COPY {table_name} ({columns}) FROM STDOUT " f"DELIMITER '{sep}' NULL '{na_value}'" ) cur.copy_expert(cmd, s_buf) conn.commit() except Exception as e: try: conn.rollback() except Exception as ex: exc.raise_with_traceback( exc.CopyError(f"{ex}\n unable to rollback") ) exc.raise_with_traceback(exc.CopyError(f"{e}\n"))
[docs] def is_table_exist(self, table_name: str) -> bool: """Return True if table exists, otherwise False. Args: table_name: Name of the table where to insert. """ _schema, _table_name = self._get_schema(table_name) df = self.select(queries.is_table_exist(_table_name, _schema)) return bool(len(df) > 0)
[docs] def is_entry_exist(self, table_name: str, **kwargs: Any) -> bool: """Return True if entry already exists, otherwise return False. Implements a simple query to verify if a specific entry exists in the table. WHERE clause is created from ``**kwargs``. Args: table_name: Name of the database table. **kwargs: Parameters to create WHERE clause. Examples: Implement query ``SELECT 1 FROM people WHERE id=5 AND num=100``. .. code:: >>> self.is_entry_exist("my_table", id=5, num=100) True """ cmd = sql.SQL("SELECT 1 FROM {} WHERE {}").format( sql.SQL(table_name), self.make_where(list(kwargs.keys())) ) res = self.select_one(cmd=cmd, values=kwargs, default=None) return res is not None
[docs] def delete_entry(self, table_name: str, **kwargs: Any) -> None: """Delete entry from the table. Implements a simple query to delete a specific entry from the table. WHERE clause is created from ``**kwargs``. Args: table_name: Name of the database table. **kwargs: Parameters to create WHERE clause. Examples: Delete rows with version=100 from the table. .. code:: >>> self.delete_entry("dict_versions", version=100) """ cmd = sql.SQL("DELETE FROM {} WHERE {}").format( sql.SQL(table_name), self.make_where(list(kwargs.keys())) ) self._execute(cmd, values=kwargs)
[docs] @staticmethod def make_where(keys: List[str]) -> sql.Composed: """Build WHERE clause from list of keys. Examples: .. code:: >>> self.make_where(["version", "task"]) "version=%s AND task=%s" """ where = list() # type: List[Union[sql.Composable]] for key in keys: if len(where) > 0: where += [sql.SQL(" AND ")] where += [sql.Identifier(key), sql.SQL("="), sql.Placeholder(key)] return sql.Composed(where)
[docs] def get_connections_count(self) -> int: """Returns the amount of active connections.""" return self.select_one(cmd=queries.conn_count(), default=0)
[docs] def resolve_primary_conflicts( self, table_name: str, data: pd.DataFrame, where: Optional[Union[str, sql.Composed]] = None, ) -> pd.DataFrame: """Resolve primary key conflicts in DataFrame. Remove all the rows from the DataFrame conflicted with primary key constraint. Parameter ``where`` is used to reduce the amount of querying data. Args: table_name: Name of the table. data: DataFrame where the primary key conflicts need to be resolved. where: WHERE clause used when querying data from the ``table_name``. Returns: DataFrame without primary key conflicts. """ p_key = self.select(queries.primary_key(table_name)) p_key = p_key["column_name"].to_list() df = data.copy() if len(p_key) > 0: if where is not None: if isinstance(where, str): where = sql.SQL(where) cmd = sql.Composed( [ sql.SQL("SELECT * FROM {} WHERE ").format(sql.SQL(table_name)), where, ] ) else: cmd = sql.SQL("SELECT * FROM {}").format(sql.SQL(table_name)) table_data = self.select(cmd) if not table_data.empty: df.set_index(p_key, inplace=True) table_data.set_index(p_key, inplace=True) # remove rows which are in table data index df = df[~df.index.isin(table_data.index)] # reset index and sort columns df = df.reset_index(level=p_key) df = df[data.columns] return df
[docs] def resolve_foreign_conflicts( self, table_name: str, parent_name: str, data: pd.DataFrame, where: Optional[Union[str, sql.Composed]] = None, ) -> pd.DataFrame: """Resolve foreign key conflicts in DataFrame. Remove all the rows from the DataFrame conflicted with foreign key constraint. Parameter ``where`` is used to reduce the amount of querying data. Args: table_name: Name of the child table, where the data needs to be inserted. parent_name: Name of the parent table. data: DataFrame with foreign key conflicts. where: WHERE clause used when querying from the ``table_name``. Returns: DataFrame without foreign key conflicts. """ df = data.copy() _schema, _table_name = self._get_schema(table_name) _parent_schema, _parent_name = self._get_schema(parent_name) foreign_key = self.select( queries.foreign_key(_table_name, _schema, _parent_name, _parent_schema) ) if len(foreign_key) > 0: if where is not None: if isinstance(where, str): where = sql.SQL(where) cmd = sql.Composed( [ sql.SQL("SELECT * FROM {} WHERE ").format(sql.SQL(parent_name)), where, ] ) else: cmd = sql.SQL("SELECT * FROM {}").format(sql.SQL(parent_name)) parent_data = self.select(cmd) if not parent_data.empty: df.set_index(foreign_key["child_column"].to_list(), inplace=True) parent_data.set_index( foreign_key["parent_column"].to_list(), inplace=True ) # remove rows which are not in parent index df = df[df.index.isin(parent_data.index)] # reset index and sort columns df = df.reset_index(level=foreign_key["child_column"].to_list()) df = df[data.columns] else: df = pd.DataFrame() return df
[docs] def encode_category( self, data: pd.DataFrame, category: str, key: str, category_table: str, category_name: Optional[str] = None, key_name: Optional[str] = None, na_value: Optional[str] = None, ) -> pd.DataFrame: """Encode categorical column. Implements writing of all the unique values in categorical column given by ``category_name`` to the table given by ``category_table``. Replaces all the values in ``category`` column in the original DataFrame with the corresponding integer values assigned to categories via serial primary key constraint. Args: data: Pandas.DataFrame with categorical column. category: Name of the categorical column in DataFrame the method is applied for. key: Name of the DataFrame column with encoded values. category_table: Name of the table with stored categories. category_name: Name of the categorical column in ``category_table``. Defaults to ``category``. key_name: Name of the column in ``category_table`` contained the encoded values. Defaults to ``key``. na_value: Missing data representation. Returns: Pandas.DataFrame with encoded category. """ if category_name is None: category_name = category if key_name is None: key_name = key if na_value is not None: data[category] = data[category].fillna(na_value) data[category] = data[category].str.replace(",", "") cat = data[[category]].drop_duplicates() cat.rename(columns={category: category_name}, inplace=True) table_data = self.select( sql.SQL("SELECT DISTINCT {} FROM {}").format( sql.SQL(category_name), sql.SQL(category_table) ) ) if not table_data.empty: cat = cat[~cat[category_name].isin(table_data[category_name].tolist())] if len(cat) > 0: self.copy_from(category_table, cat, format_data=True) cmd = sql.SQL("SELECT {} AS {}, {} AS {} FROM {}").format( sql.Identifier(key_name), sql.Identifier(key), sql.Identifier(category_name), sql.Identifier(category), sql.SQL(category_table), ) df = self.select(cmd) data[key] = data[category].map(df.set_index(category)[key].to_dict()) return data
[docs] def encode_composite_category( self, data: pd.DataFrame, categories: Dict[str, str], key: str, category_table: str, key_name: Optional[str] = None, na_value: Optional[str] = None, ) -> pd.DataFrame: """Encode categories represented by multiple columns. Implements writing of all the unique combinations given by multiple columns in DataFrame to the table given by ``category_table``. Dictionary ``categories`` provides a mapping between DataFrame and ``category_table`` column names. Args: data: Pandas.DataFrame with categorical columns. categories: Dictionary provided the mapping between column names. Dict keys provide names of columns in ``data`` represented category, values represent column names in ``category_table``. key: Name of the DataFrame column with encoded values. category_table: Name of the table with stored categories. key_name: Name of the column in ``category_table`` contained the encoded values. Defaults to ``key``. na_value: Missing data representation. Returns: Pandas.DataFrame with encoded category. """ if key_name is None: key_name = key for category in categories.keys(): if na_value is not None: data[category] = data[category].fillna(na_value) if data[category].dtype == object and isinstance( data[category].iloc[0], str ): data[category] = data[category].str.replace(",", "") cat = data.drop_duplicates(subset=list(categories.keys())) cat.rename(columns=categories, inplace=True) table_data = self.select( sql.SQL("SELECT * FROM {}").format(sql.SQL(category_table)) ) composite_key = list(categories.values()) if not table_data.empty: cat.set_index(composite_key, inplace=True) table_data.set_index(composite_key, inplace=True) cat = cat[~cat.index.isin(table_data.index)] cat.reset_index(level=composite_key, inplace=True) if len(cat) > 0: self.copy_from(category_table, cat, format_data=True) df = self.select(sql.SQL("SELECT * FROM {}").format(sql.SQL(category_table))) df.rename(columns={v: k for k, v in categories.items()}, inplace=True) df.rename(columns={key_name: key}, inplace=True) columns = list(categories.keys()) + [key] data = data.merge(df[columns], how="inner", on=list(categories.keys())) return data
def _table_columns(self, table_name: str) -> pd.DataFrame: """Return columns attributes of the given table. Args: table_name: Name of the table. Returns: Pandas.DataFrame with the names and data types of all the columns of the given table. """ _schema, _table_name = self._get_schema(table_name) return self.select(queries.column_names(_table_name, _schema)) def _format_data( self, data: pd.DataFrame, table_name: str, sep: str = "," ) -> pd.DataFrame: """Formatting DataFrame before applying COPY FROM.""" table_columns = self._table_columns(table_name) columns = [] # type: List[str] for row in table_columns.itertuples(): column = row.column_name if column in data.columns: columns += [column] if row.data_type in ["smallint", "integer", "bigint"]: if data[column].dtype == np.float64: data[column] = data[column].round().astype("Int64") elif row.data_type in ["text"]: try: data[column] = data[column].str.replace(sep, "") except AttributeError: continue return data[columns]