# ===============================================================================
# Copyright 2011 Jake Ross
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===============================================================================
# =============enthought library imports=======================
import os
from datetime import datetime, timedelta
from threading import Lock
import six
from sqlalchemy import create_engine, distinct, MetaData
from sqlalchemy.exc import (
SQLAlchemyError,
InvalidRequestError,
StatementError,
DBAPIError,
OperationalError,
)
from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm.exc import MultipleResultsFound, NoResultFound
from traits.api import (
Password,
Bool,
Str,
on_trait_change,
Any,
Property,
cached_property,
Int,
)
from pychron.database.core.base_orm import AlembicVersionTable
from pychron.database.core.query import compile_query
from pychron.loggable import Loggable
from pychron.regex import IPREGEX
def obscure_host(h):
    """
    Return ``h`` with everything but its last dot-separated field masked.

    Only values matching ``IPREGEX`` are rewritten (to ``x.x.x.<last field>``);
    anything else is returned unchanged.
    """
    if not IPREGEX.match(h):
        return h

    last_field = h.rsplit(".", 1)[-1]
    return "x.x.x.{}".format(last_field)
def binfunc(ds, hours):
    """
    Group time-ordered records into padded ``(start, end)`` datetime bins.

    A new bin is started whenever the gap between two consecutive timestamps
    exceeds ``hours``. Every bin is padded on both sides by a quarter of
    ``hours``.

    :param ds: iterable of objects exposing a ``timestamp`` datetime attribute
    :param hours: maximum gap, in hours, tolerated inside a single bin
    :return: generator of ``(low, high)`` datetime tuples; yields nothing when
             ``ds`` is empty
    """
    stamps = [dx.timestamp for dx in ds]
    if not stamps:
        # nothing to bin; the original indexed stamps[0] and raised IndexError
        return

    delta_seconds = hours * 3600
    td = timedelta(seconds=delta_seconds * 0.25)

    bin_start = stamps[0]
    # pair every timestamp with its predecessor; the first one is paired with itself
    for previous, current in zip(stamps[:1] + stamps, stamps):
        if (current - previous).total_seconds() > delta_seconds:
            yield bin_start - td, previous + td
            bin_start = current

    yield bin_start - td, stamps[-1] + td
class SessionCTX(object):
    """
    Context manager handing out a database session for ``parent``.

    With ``use_parent_session=True`` the parent's own session is used
    (``parent.create_session()`` on enter, ``parent.close_session()`` on exit).
    Otherwise a private session is built from ``parent.session_factory`` and
    temporarily installed as ``parent.session``; on exit it is closed and the
    parent's previous session (possibly ``None``) is put back.
    """

    def __init__(self, parent, use_parent_session=True):
        self._use_parent_session = use_parent_session
        self._parent = parent
        self._session = None
        self._psession = None

    def __enter__(self):
        if self._use_parent_session:
            self._parent.create_session()
            return self._parent.session

        # remember what the parent had so __exit__ can restore it
        self._psession = self._parent.session
        self._session = self._parent.session_factory()
        self._parent.session = self._session
        return self._session

    def __exit__(self, exc_type, exc_val, exc_tb):
        if self._session is None:
            self._parent.close_session()
            return

        self._session.close()
        # always restore, even when the previous session was None, so the
        # parent is never left pointing at the session that was just closed
        self._parent.session = self._psession
        self._session = None
        self._psession = None
class MockQuery:
    """
    Do-nothing query object.

    ``join``, ``filter`` and ``order_by`` accept any arguments and return the
    same object so calls can be chained; ``all`` always returns an empty list.
    """

    def _chain(self, *args, **kw):
        # type: (object, object) -> object
        return self

    # every builder-style call is the same no-op
    join = _chain
    filter = _chain
    order_by = _chain

    def all(self, *args, **kw):
        return []
class MockSession:
    """Placeholder session whose ``query`` hands back a fresh ``MockQuery``."""

    def query(self, *args, **kw):
        placeholder = MockQuery()
        return placeholder

    # def __getattr__(self, item):
    #     return
class DatabaseAdapter(Loggable):
"""
The DatabaseAdapter is a base class for interacting with a SQLAlchemy database.
Two main subclasses are used by pychron, IsotopeAdapter and MassSpecDatabaseAdapter.
This class provides attributes for describing the database url, i.e. host, user, password etc,
and methods for connecting and opening database sessions.
It also provides some helper functions used extensively by the subclasses, e.g. ``_add_item``,
``_retrieve_items``
"""
# active session; assigned by ``create_session`` and cleared by
# ``close_session`` / ``reset_connection``
session = None
sess_stack = 0
# when True ``_add_item`` re-raises the SQLAlchemyError after rolling back
reraise = False
# result of the last ``connect`` call
connected = Bool(False)
# database backend; ``enabled`` is True for mysql, sqlite, postgresql and mssql
kind = Str
prev_kind = Str
username = Str
host = Str
password = Password
# appended to mysql urls as ``connect_timeout`` when non-zero
timeout = Int
# ``sessionmaker`` built by ``connect``; reset to None by ``reset_connection``
session_factory = None
application = Any
# name of the method ``_test_db_connection`` calls to probe the connection
test_func = "get_versions"
version_func = "get_versions"
# ``autoflush`` / ``autocommit`` are handed to ``sessionmaker`` in ``connect``;
# ``_add_item`` also consults them together with ``commit_on_add``
autoflush = True
autocommit = False
commit_on_add = True
# name used when writing to database
# save_username = Str
# set by ``reset_connection`` and cleared by ``connect``; the url properties
# below declare it as their dependency
connection_parameters_changed = Bool
url = Property(depends_on="connection_parameters_changed")
datasource_url = Property(depends_on="connection_parameters_changed")
# file path used to build the url when ``kind`` is "sqlite"
path = Str
# passed to ``create_engine`` as ``echo``
echo = False
# when True the retrieve helpers log the compiled query
verbose_retrieve_query = False
verbose = True
# message describing the most recent connection problem; written by ``connect``
connection_error = Str
# replaced with a ``threading.Lock`` per instance in ``__init__``
_session_lock = None
modified = False
_trying_to_add = False
# cleared by ``_test_db_connection`` after an OperationalError and re-enabled
# by ``connect`` once the connection parameters change
_test_connection_enabled = True
def __init__(self, *args, **kw):
    """Run the base-class initialiser, then give this instance its own session lock."""
    super(DatabaseAdapter, self).__init__(*args, **kw)

    # overrides the class-level ``None`` placeholder
    lock = Lock()
    self._session_lock = lock
def create_all(self, metadata):
    """
    Build a database schema with the current connection

    :param metadata: SQLAlchemy MetaData object
    """
    engine = self.session.bind
    metadata.create_all(engine)
# def session_ctx(self, sess=None, commit=True, rollback=True):
# """
# Make a new session context.
#
# :return: ``SessionCTX``
# """
# with self._session_lock:
# if sess is None:
# sess = self.sess
# return SessionCTX(sess, parent=self, commit=commit, rollback=rollback)
# number of outstanding ``create_session`` calls; ``close_session`` only
# closes the session once this drops back to zero
_session_cnt = 0


def session_ctx(self, use_parent_session=True):
    """
    Make a new session context.

    :param use_parent_session: forwarded to ``SessionCTX``
    :return: ``SessionCTX``
    """
    with self._session_lock:
        ctx = SessionCTX(self, use_parent_session)
    return ctx
def create_session(self, force=False):
    """
    Ensure ``self.session`` is set and bump the session use count.

    If ``connect(test=False)`` fails a ``MockSession`` is installed instead.

    :param force: close any existing session, create a fresh one and reset
                  the use count to 1
    """
    if not self.connect(test=False):
        self.session = MockSession()
        return

    if not self.session_factory:
        self.warning("no session factory")
        return

    if force:
        self.debug("force create new session {}".format(id(self)))
        if self.session:
            self.session.close()
        self.session = self.session_factory()
        self._session_cnt = 1
        return

    if not self.session:
        self.session = self.session_factory()
    self._session_cnt += 1
def close_session(self):
    """
    Flush the current session and release one use of it.

    The session is only closed and cleared once the use count reaches zero.
    Nothing happens when there is no session or it is a ``MockSession``.
    """
    if not self.session or isinstance(self.session, MockSession):
        return

    self.session.flush()
    self._session_cnt -= 1
    if self._session_cnt:
        # still in use by an outer caller
        return

    self.debug("close session {}".format(id(self)))
    self.session.close()
    self.session = None
@property
def enabled(self):
    """True when ``kind`` names one of the supported database backends."""
    supported_kinds = ("mysql", "sqlite", "postgresql", "mssql")
    return self.kind in supported_kinds
@property
def save_username(self):
    """Return the ``username`` held by ``pychron.globals.globalv``."""
    # imported lazily, as in the rest of this class
    from pychron.globals import globalv

    name = globalv.username
    return name
@on_trait_change("username,host,password,name,kind,path")
def reset_connection(self):
    """
    Trip the ``connection_parameters_changed`` flag and drop the current
    session factory and session. The next ``connect`` call will use the new values.
    """
    self.connection_parameters_changed = True
    # force a fresh engine/sessionmaker and session on the next connect
    self.session_factory = self.session = None
# @caller
def connect(
    self, test=True, force=False, warn=True, version_warn=True, attribute_warn=False
):
    """
    Connect to the database

    Builds the engine and ``session_factory`` when not yet connected, when
    ``force`` is set, or when the connection parameters changed since the last
    call. ``connection_error`` is empty after a successful connection.

    :param test: Test the connection by running ``test_func``
    :param force: Test connection even if connection parameters haven't changed
    :param warn: Warn if the connection test fails
    :param version_warn: Warn if database/pychron versions don't match
    :param attribute_warn: not used by this method
    :return: True if connected else False
    :rtype: bool
    """
    self.connection_error = ""
    if force:
        self.debug("forcing database connection")

    if self.connection_parameters_changed:
        # new parameters: allow connection testing again and rebuild the engine
        self._test_connection_enabled = True
        force = True

    if not self.connected or force:
        self.connected = False

        pool_recycle = 600
        if self.kind == "sqlite":
            # sqlite is a local file: no connection test, no pool recycling
            self.connected = True
            test = False
            pool_recycle = -1

        self.connection_error = (
            'Database "{}" kind not set. '
            'Set in Preferences. current kind="{}"'.format(self.name, self.kind)
        )

        if not self.enabled:
            from pychron.core.ui.gui import invoke_in_main_thread

            invoke_in_main_thread(self.warning_dialog, self.connection_error)
        else:
            url = self.url
            if url is not None:
                self.info(
                    "{} connecting to database {}".format(id(self), self.public_url)
                )
                engine = create_engine(
                    url, echo=self.echo, pool_recycle=pool_recycle
                )
                self.session_factory = sessionmaker(
                    bind=engine,
                    autoflush=self.autoflush,
                    expire_on_commit=False,
                    autocommit=self.autocommit,
                )
                if test:
                    if not self._test_connection_enabled:
                        # a previous test already failed; stay quiet this time
                        warn = False
                    elif self.test_func:
                        self.connected = self._test_db_connection(version_warn)
                    else:
                        self.connected = True
                else:
                    self.connected = True

                if self.connected:
                    self.info("connected to db {}".format(self.public_url))
                    # do not leave the "kind not set" placeholder behind
                    # once the connection has actually succeeded
                    self.connection_error = ""
                else:
                    self.connection_error = (
                        'Not Connected to Database "{}".\n'
                        "Access Denied for user= {} host= {}\nurl= {}".format(
                            self.name, self.username, self.host, self.public_url
                        )
                    )
                    if warn:
                        from pychron.core.ui.gui import invoke_in_main_thread

                        invoke_in_main_thread(
                            self.warning_dialog, self.connection_error
                        )

    self.connection_parameters_changed = False
    return self.connected
# def initialize_database(self):
# pass
def rollback(self):
    """Roll back the current session, if there is one."""
    sess = self.session
    if not sess:
        return
    sess.rollback()
def flush(self):
    """
    flush the session

    A failing flush is logged as a warning and the session is rolled back.
    Only ``Exception`` subclasses are handled, so ``KeyboardInterrupt`` and
    ``SystemExit`` propagate.
    """
    if not self.session:
        return

    try:
        self.session.flush()
    except Exception as e:
        # same reporting style as ``commit``
        self.warning("Flush exception: {}".format(e))
        self.session.rollback()
def expire(self, i):
    """Expire ``i`` on the current session; no-op without a session."""
    sess = self.session
    if not sess:
        return
    sess.expire(i)
def expire_all(self):
    """Expire every object held by the current session; no-op without a session."""
    sess = self.session
    if not sess:
        return
    sess.expire_all()
def commit(self):
    """
    commit the session

    A failing commit is logged as a warning and the session is rolled back.
    Only ``Exception`` subclasses are handled, so ``KeyboardInterrupt`` and
    ``SystemExit`` are no longer swallowed.
    """
    if not self.session:
        return

    try:
        self.session.commit()
    except Exception as e:
        self.warning("Commit exception: {}".format(e))
        self.session.rollback()
def delete(self, obj):
    """Mark ``obj`` for deletion on the current session; no-op without a session."""
    sess = self.session
    if not sess:
        return
    sess.delete(obj)
def post_commit(self):
    """Set ``modified`` when an add was attempted (``_trying_to_add`` is true)."""
    if not self._trying_to_add:
        return
    self.modified = True
def add_item(self, *args, **kw):
    """Public wrapper: forward all arguments to ``_add_item`` and return its result."""
    added = self._add_item(*args, **kw)
    return added
# def get_session(self):
# """
# return the current session or make a new one
#
# :return: Session
# """
# sess = self.sess
# if sess is None:
# self.debug('$$$$$$$$$$$$$$$$ session is None')
# sess = self.session_factory()
#
# return sess
def get_migrate_version(self, **kw):
    """
    Query the AlembicVersionTable

    :return: the single row of ``AlembicVersionTable`` (via ``Query.one()``)
    """
    return self.session.query(AlembicVersionTable).one()
def get_versions(self, **kw):
    """
    Placeholder that does nothing and returns ``None``.

    It is the default ``test_func``/``version_func`` of the adapter.
    """
    return None
@property
def public_datasource_url(self):
    """
    Short ``<location>:<name>`` description of the data source.

    For sqlite this is ``<parent directory name>:<file name>``; for every other
    kind it is the host, passed through ``obscure_host``, followed by the
    database name.
    """
    if self.kind != "sqlite":
        return "{}:{}".format(obscure_host(self.host), self.name)

    folder, filename = os.path.split(self.path)
    return "{}:{}".format(os.path.basename(folder), filename)
@cached_property
def _get_datasource_url(self):
if self.kind == "sqlite":
url = "{}:{}".format(
os.path.basename(os.path.dirname(self.path)),
os.path.basename(self.path),
)
else:
url = "{}:{}".format(self.host, self.name)
return url
@property
def public_url(self):
    """
    Connection url without the password.

    For sqlite this is ``<parent directory name>:<file name>``; otherwise
    ``<kind>://<user>@<host>/<name>``.
    """
    kind, user, host, name = self.kind, self.username, self.host, self.name

    if kind != "sqlite":
        return "{}://{}@{}/{}".format(kind, user, host, name)

    folder, filename = os.path.split(self.path)
    return "{}:{}".format(os.path.basename(folder), filename)
@cached_property
def _get_url(self):
kind = self.kind
password = self.password
user = self.username
host = self.host
name = self.name
timeout = self.timeout
if kind in ("mysql", "postgresql", "mssql"):
if kind == "mysql":
# add support for different mysql drivers
driver = self._import_mysql_driver()
if driver is None:
return
elif kind == "mssql":
driver = self._import_mssql_driver()
if driver is None:
return
else:
driver = "pg8000"
if password:
user = "{}:{}".format(user, password)
prefix = "{}+{}://{}@".format(kind, driver, user)
if driver == "pyodbc":
url = "{}{}".format(prefix, name)
else:
url = "{}{}/{}".format(prefix, host, name)
if kind == "mysql" and self.timeout:
url = "{}?connect_timeout={}".format(url, timeout)
else:
url = "sqlite:///{}".format(self.path)
return url
def _import_mssql_driver(self):
driver = None
try:
import pyodbc
driver = "pyodbc"
except ImportError:
try:
import pymssql
driver = "pymssql"
except ImportError:
pass
self.info('using mssql driver="{}"'.format(driver))
return driver
def _import_mysql_driver(self):
try:
"""
pymysql
https://github.com/petehunt/PyMySQL/
"""
import pymysql
driver = "pymysql"
except ImportError:
try:
import _mysql
driver = "mysqldb"
except ImportError:
self.warning_dialog(
"A mysql driver was not found. Install PyMySQL or MySQL-python"
)
return
self.info('using mysql driver="{}"'.format(driver))
return driver
def _test_db_connection(self, version_warn):
self.connected = True
self.create_session()
try:
self.info("testing database connection {}".format(self.test_func))
vers = getattr(self, self.test_func)(reraise=True)
if version_warn:
self._version_warn_hook()
connected = True
except OperationalError:
self.warning("Operational connection failed to {}".format(self.public_url))
connected = False
self._test_connection_enabled = False
except Exception as e:
self.debug_exception()
self.warning(
"connection failed to {} exception={}".format(self.public_url, e)
)
connected = False
finally:
self.info("closing test session")
self.close_session()
return connected
def _version_warn_hook(self):
pass
# def test_version(self):
# ver = getattr(self, self.version_func)()
# ver = ver.version_num
# aver = version.__alembic__
# if ver != aver:
# return 'Database is out of data. Pychron ver={}, Database ver={}'.format(aver, ver)
def _add_item(self, obj):
"""
Add ``obj`` to the current session.

:param obj: object handed to ``session.add``
:return: ``obj`` on success; ``None`` when there is no session or when a
SQLAlchemyError was caught and ``reraise`` is False
:raises SQLAlchemyError: re-raised after rollback when ``reraise`` is True
"""
sess = self.session
if sess:
sess.add(obj)
try:
# NOTE(review): indentation is lost in this listing, so it cannot be told
# here whether the ``modified`` / ``_trying_to_add`` assignments and the
# commit below are nested under ``if self.autoflush`` -- confirm against
# the original file before relying on either reading.
if self.autoflush:
sess.flush()
self.modified = True
self._trying_to_add = True
# commit only when the session is not autocommitting and commit_on_add is set
if not self.autocommit and self.commit_on_add:
sess.commit()
return obj
except SQLAlchemyError as e:
import traceback
self.debug(
"add_item exception {} {}".format(obj, traceback.format_exc())
)
# undo the failed add before optionally re-raising
sess.rollback()
if self.reraise:
raise
else:
self.critical("No session")
def _add_unique(self, item, attr, name):
nitem = getattr(self, "get_{}".format(attr))(name)
if nitem is None:
self.info("adding {}= {}".format(attr, name))
self._add_item(item)
nitem = item
return nitem
def _get_date_range(self, q, asc, desc, hours=0):
lan = q.order_by(asc).first()
han = q.order_by(desc).first()
lan = datetime.now() if not lan else lan.timestamp
han = datetime.now() if not han else han.timestamp
td = timedelta(hours=hours)
return lan - td, han + td
def _delete_item(self, value, name=None):
if name is not None:
func = getattr(self, "get_{}".format(name))
item = func(value)
else:
item = value
if item:
self.debug("deleting value={},name={},item={}".format(value, name, item))
self.session.delete(item)
def _retrieve_items(
self,
table,
joins=None,
filters=None,
limit=None,
order=None,
distinct_=False,
query_hook=None,
reraise=False,
func="all",
group_by=None,
verbose_query=False,
):
sess = self.session
if sess is None or isinstance(sess, MockSession):
self.debug("USING MOCKSESSION************** {}".format(sess))
return []
if distinct_:
if isinstance(distinct_, bool):
q = sess.query(distinct(table))
else:
q = sess.query(distinct(distinct_))
elif isinstance(table, tuple):
q = sess.query(*table)
else:
q = sess.query(table)
if joins:
try:
for ji in joins:
if ji != table:
q = q.join(ji)
except InvalidRequestError:
if reraise:
raise
if filters is not None:
for fi in filters:
q = q.filter(fi)
if order is not None:
if not isinstance(order, tuple):
order = (order,)
q = q.order_by(*order)
if group_by is not None:
if not isinstance(order, tuple):
group_by = (group_by,)
q = q.group_by(*group_by)
if limit is not None:
q = q.limit(limit)
if query_hook:
q = query_hook(q)
if verbose_query or self.verbose_retrieve_query:
# print compile_query(q)
self.debug(compile_query(q))
items = self._query(q, func, reraise)
if items is None:
items = []
return items
def _retrieve_first(self, table, value=None, key="name", order_by=None):
if value is not None:
if not isinstance(value, (str, int, six.text_type, int, float)):
return value
q = self.session.query(table)
if value is not None:
q = q.filter(getattr(table, key) == value)
try:
if order_by is not None:
q = q.order_by(order_by)
return q.first()
except SQLAlchemyError as e:
print("execption first", e)
return
def _query_all(self, q, **kw):
ret = self._query(q, "all", **kw)
return ret or []
def _query_first(self, q, **kw):
return self._query(q, "first", **kw)
def _query_one(self, q, **kw):
q = q.limit(1)
return self._query(q, "one", **kw)
def _query(self, q, func, reraise=False, verbose_query=False):
if verbose_query:
try:
cq = compile_query(q)
self.debug(cq)
except BaseException:
cq = "Query failed to compile"
self.debug_exception()
# print compile_query(q)
f = getattr(q, func)
try:
return f()
except NoResultFound:
if verbose_query:
self.info("no results found for query -- {}".format(cq))
except OperationalError as e:
self.debug("_query operation exception")
self.debug_exception()
except SQLAlchemyError as e:
if self.verbose:
self.debug("_query exception {}".format(e))
try:
self.rollback()
self.reset_connection()
self.connect()
except BaseException:
pass
if reraise:
raise e
def _append_filters(self, f, kw):
filters = kw.get("filters", [])
if isinstance(f, (tuple, list)):
filters.extend(f)
else:
filters.append(f)
kw["filters"] = filters
return kw
def _append_joins(self, f, kw):
joins = kw.get("joins", [])
if isinstance(f, (tuple, list)):
joins.extend(f)
else:
joins.append(f)
kw["joins"] = joins
return kw
def _retrieve_item(
    self,
    table,
    value,
    key="name",
    last=None,
    joins=None,
    filters=None,
    options=None,
    verbose=True,
    verbose_query=False,
):
    """
    Fetch a single row of ``table`` whose ``key`` column(s) equal ``value``.

    If no session is open one is created for the lookup and closed again
    afterwards, even when the lookup raises.

    :param table: mapped class to query
    :param value: value, or list/tuple of values, to match; anything that is
                  not a str/int/float/list/tuple is returned unchanged
    :param key: column name, or list/tuple of names, paired with ``value``
    :param last: optional ordering clause applied before fetching
    :param joins: iterable of join targets; entries equal to ``table`` are skipped
    :param filters: iterable of extra filter clauses
    :param options: not used by this method
    :param verbose: log when no row or several rows are found
    :param verbose_query: log the compiled query
    :return: the row, or ``None`` when nothing was found or all retries failed
    """
    if not isinstance(value, (str, int, float, six.text_type, list, tuple)):
        return value

    if not isinstance(value, (list, tuple)):
        value = (value,)

    if not isinstance(key, (list, tuple)):
        key = (key,)

    def _retrieve(s):
        q = s.query(table)

        if joins:
            try:
                for ji in joins:
                    if ji != table:
                        q = q.join(ji)
            except InvalidRequestError:
                pass

        if filters is not None:
            for fi in filters:
                q = q.filter(fi)

        for k, v in zip(key, value):
            q = q.filter(getattr(table, k) == v)

        if last:
            q = q.order_by(last)

        if verbose_query or self.verbose_retrieve_query:
            self.debug(compile_query(q))

        ntries = 3
        import traceback

        for _ in range(ntries):
            try:
                return q.one()
            except (DBAPIError, OperationalError, StatementError):
                # connection-level failure: roll back and try again
                self.debug(traceback.format_exc())
                s.rollback()
                continue
            except MultipleResultsFound:
                if verbose:
                    self.debug(
                        "multiples row found for {} {} {}. Trying to get last row".format(
                            table.__tablename__, key, value
                        )
                    )
                try:
                    if hasattr(table, "id"):
                        q = q.order_by(table.id.desc())
                    return q.limit(1).all()[-1]
                except (SQLAlchemyError, IndexError, AttributeError):
                    if verbose:
                        self.debug(
                            "no rows for {} {} {}".format(
                                table.__tablename__, key, value
                            )
                        )
                break
            except NoResultFound:
                if verbose and self.verbose:
                    self.debug(
                        "no row found for {} {} {}".format(
                            table.__tablename__, key, value
                        )
                    )
                break

    close = False
    if self.session is None:
        self.create_session()
        close = True

    try:
        return _retrieve(self.session)
    finally:
        # release the session opened above even if the lookup raised
        if close:
            self.close_session()
def _get_items(
self,
table,
gtables,
join_table=None,
filter_str=None,
limit=None,
order=None,
key=None,
):
"""
Query rows of ``table`` via ``self._get_query`` and return them.

:param table: mapped class to query; its ``id`` column is used for the final ordering
:param gtables: mapping used to resolve ``join_table`` when it is given as a string
:param join_table: join target, or a key into ``gtables``
:param filter_str: forwarded to ``self._get_query``
:param limit: maximum number of rows
:param order: ordering clause or list of clauses
:param key: when given, return ``getattr(row, key)`` for each row instead of the rows
:return: list of rows, or of attribute values when ``key`` is set
"""
# NOTE(review): ``_get_query`` is not defined in the visible part of this class;
# presumably supplied by a subclass -- confirm.
if isinstance(join_table, str):
join_table = gtables[join_table]
q = self._get_query(table, join_table=join_table, filter_str=filter_str)
if order:
# accept either a single clause or a list of clauses
for o in order if isinstance(order, list) else [order]:
q = q.order_by(o)
if limit:
q = q.limit(limit)
# reorder based on id
# NOTE(review): indentation is lost in this listing; whether ``order_by(table.id)``
# is applied only together with ``from_self()`` (i.e. nested under ``if order``)
# cannot be told from here -- confirm against the original file.
if order:
q = q.from_self()
q = q.order_by(table.id)
res = q.all()
if key:
return [getattr(ri, key) for ri in res]
return res
class PathDatabaseAdapter(DatabaseAdapter):
    """
    Adapter that can attach a path record to another record.

    Subclasses set ``path_table`` to the class used to build path records.
    """

    path_table = None

    def add_path(self, rec, path, **kw):
        """
        Build a ``path_table`` record for ``path`` and assign it to ``rec.path``.

        :param rec: record receiving the new path object
        :param path: file path, split into ``root`` and ``filename`` keywords
        :param kw: extra keywords passed to ``path_table``
        :return: the new ``path_table`` instance
        :raises NotImplementedError: when ``path_table`` is None
        """
        if self.path_table is None:
            raise NotImplementedError

        keywords = self._get_path_keywords(path, kw)
        p = self.path_table(**keywords)
        rec.path = p
        return p

    def _get_path_keywords(self, path, args):
        """Store the directory and base name of ``path`` in ``args`` and return it."""
        root, filename = os.path.split(path)
        args["root"] = root
        args["filename"] = filename
        return args
class SQLiteDatabaseAdapter(DatabaseAdapter):
    """Adapter preset to ``kind = "sqlite"`` that can build a missing database file."""

    kind = "sqlite"

    def build_database(self):
        """
        Connect without testing and, when no file exists at ``path``, call
        ``_build_database`` with the current session and a new ``MetaData``.
        """
        self.connect(test=False)
        if os.path.isfile(self.path):
            return

        meta = MetaData()
        self._build_database(self.session, meta)

    def _build_database(self, sess, meta):
        """Create the schema; must be implemented by subclasses."""
        raise NotImplementedError
# ============= EOF =============================================