Blob Blame History Raw

# cache.py - caching layer for pynslcd
#
# Copyright (C) 2012, 2013 Arthur de Jong
#
# This library is free software; you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public
# License as published by the Free Software Foundation; either
# version 2.1 of the License, or (at your option) any later version.
#
# This library is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public
# License along with this library; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
# 02110-1301 USA

import datetime
import os
import sys

import sqlite3


# TODO: probably create a config table
# FIXME: have some way to remove stale entries from the cache if all items
# from LDAP are queried (perhaps use the TTL from all requests)


class regroup(object):
    """Iterator that merges consecutive rows sharing the same key.

    Rows from results that follow each other and have equal values in the
    group_by columns are collapsed into a single row, returned as a list.
    In that row the group_column entry is replaced by a list of that
    column's non-None values from all merged rows; the other columns are
    taken from the first row of the run. As with itertools.groupby(), only
    adjacent rows are merged, so results should be ordered by the key.
    """

    def __init__(self, results, group_by=None, group_column=None):
        """Regroup the results in the group column by the key columns.

        results is any iterable of indexable rows, group_by an iterable of
        key column indexes (None means no key columns, so all rows form a
        single group) and group_column the index of the column to collect.
        """
        # copy the key columns so later changes by the caller do not matter
        self.group_by = tuple(group_by or ())
        self.group_column = group_column
        self.it = iter(results)
        # a unique sentinel: the first call to next() always fetches a row
        self.tgtkey = self.currkey = self.currvalue = object()

    def keyfunc(self, row):
        """Return the tuple of key column values of row."""
        return tuple(row[x] for x in self.group_by)

    def __iter__(self):
        return self

    def next(self):
        """Return the next merged row as a list.

        Raises StopIteration when the underlying results are exhausted.
        """
        # find a start row: skip rows that still belong to the previous group
        while self.currkey == self.tgtkey:
            self.currvalue = next(self.it)    # Exit on StopIteration
            self.currkey = self.keyfunc(self.currvalue)
        self.tgtkey = self.currkey
        # turn the result row into a list of columns
        row = list(self.currvalue)
        # replace the group column with the values collected from the run
        row[self.group_column] = list(self._grouper(self.tgtkey))
        return row

    # Python 3 iterator protocol (Python 2 uses next())
    __next__ = next

    def _grouper(self, tgtkey):
        """Generate the non-None group column values of the current run."""
        while self.currkey == tgtkey:
            value = self.currvalue[self.group_column]
            if value is not None:
                yield value
            try:
                self.currvalue = next(self.it)
            except StopIteration:
                # end of the data also ends this group; returning instead of
                # letting StopIteration escape the generator is required by
                # PEP 479 (it would otherwise become a RuntimeError)
                return
            self.currkey = self.keyfunc(self.currvalue)


class Query(object):
    """Helper class to build an SQL query for the cache.

    The base query is extended with a WHERE clause made of the conditions
    registered through add_where(), combined with AND.
    """

    def __init__(self, query):
        self.query = query
        # SQL conditions, combined with AND by execute()
        self.wheres = []
        # values for the ? placeholders, in the order the conditions were added
        self.parameters = []

    def add_where(self, where, parameters):
        """Add the SQL condition where and the values for its placeholders."""
        self.wheres.append(where)
        self.parameters += parameters

    def execute(self, con):
        """Run the query on connection con and return the cursor."""
        query = self.query
        if self.wheres:
            # parenthesise every condition so that one containing OR is not
            # split apart by the higher precedence of AND
            query += ' WHERE ' + ' AND '.join(
                '(%s)' % where for where in self.wheres)
        cursor = con.cursor()
        return cursor.execute(query, self.parameters)


class Cache(object):
    """Base class for a cache of lookup results kept in SQLite.

    Subclasses are expected to provide ``create_sql`` (the SQL script that
    creates the cache tables) and may override the class attributes below
    or set ``tables`` themselves.
    """

    # custom SELECT statement for retrieve(); None selects everything from
    # the main table
    retrieve_sql = None
    # maps a retrieve() parameter name to a custom SQL condition; every ? in
    # that condition is bound to the parameter value
    retrieve_by = dict()
    # key columns used when regrouping the retrieve() results
    group_by = ()
    # column positions whose values retrieve() collects into lists
    group_columns = ()

    def __init__(self):
        self.con = _get_connection()
        owner = sys.modules[self.__module__]
        self.db = owner.__name__
        # by default there is one main table named after the module
        if not hasattr(self, 'tables'):
            self.tables = ['%s_cache' % self.db]
        self.create()

    def create(self):
        """Create the needed tables if necessary."""
        self.con.executescript(self.create_sql)

    def store(self, *values):
        """Store one entry in the cache.

        The values are given in the order returned by the Request.convert()
        function. Every value that is not a list, tuple or set goes into the
        main table (self.tables[0]), followed by the current time as the last
        column; the row replaces an existing one if the table says so. The
        n-th list/tuple/set value replaces the rows in self.tables[n] that
        belong to the first value with (first value, item) rows.
        """
        sequence_types = (list, tuple, set)
        flat = [v for v in values if not isinstance(v, sequence_types)]
        nested = [v for v in values if isinstance(v, sequence_types)]
        # the modification time is stored as the last column
        flat.append(datetime.datetime.now())
        placeholders = ', '.join(['?'] * len(flat))
        self.con.execute('''
            INSERT OR REPLACE INTO %s
            VALUES
              (%s)
            ''' % (self.tables[0], placeholders), flat)
        for position, items in enumerate(nested, 1):
            table = self.tables[position]
            key = values[0]
            # drop the old one-to-many rows before inserting the new ones
            self.con.execute('''
                DELETE FROM %s
                WHERE `%s` = ?
                ''' % (table, self.db), (key, ))
            self.con.executemany('''
                INSERT INTO %s
                VALUES
                  (?, ?)
                ''' % table, [(key, item) for item in items])

    def retrieve(self, parameters):
        """Return the cached entries that match parameters.

        parameters maps column (or retrieve_by) names to values; all
        conditions must match. The rows are regrouped on group_columns and
        returned lazily as lists without their last (mtime) column.
        """
        base_sql = self.retrieve_sql or '''
            SELECT *
            FROM %s
            ''' % self.tables[0]
        query = Query(base_sql)
        for name, value in (parameters or {}).items():
            default = '`%s`.`%s` = ?' % (self.tables[0], name)
            condition = self.retrieve_by.get(name, default)
            query.add_where(condition, [value] * condition.count('?'))
        # FIXME: find a nice way to turn group_by and group_columns into names
        rows = query.execute(self.con)
        # collapse the multi-valued columns from the last one to the first;
        # each pass keys on group_by plus the multi-valued columns before it
        key_columns = list(self.group_by + self.group_columns)
        for column in reversed(self.group_columns):
            key_columns.pop()
            rows = regroup(rows, key_columns, column)
        # drop the trailing mtime column
        return (list(row)[:-1] for row in rows)

    def __enter__(self):
        """Enter the connection's transaction context and return it."""
        return self.con.__enter__()

    def __exit__(self, *args):
        """Leave the connection's transaction context."""
        return self.con.__exit__(*args)


# the connection to the sqlite database, shared by all caches; it is opened
# lazily by _get_connection()
_connection = None


# FIXME: make thread safe (is this needed the way the caches are initialised?)
def _get_connection(filename='/tmp/pynslcd_cache.sqlite'):
    """Return the shared SQLite connection, opening it on first use.

    filename is only used by the call that opens the connection; later calls
    return the already opened connection and ignore it. Missing parent
    directories of filename are created.

    The connection returns sqlite3.Row objects, converts values using the
    declared column types (PARSE_DECLTYPES) and may be used from threads
    other than the one that opened it (check_same_thread=False).
    """
    global _connection
    if _connection is None:
        # NOTE(review): a fixed, predictable file name in the world-writable
        # /tmp can be pre-created or symlinked by another local user;
        # consider a directory only this daemon can write to
        dirname = os.path.dirname(filename)
        # a bare file name (or ':memory:') has no directory part to create
        if dirname and not os.path.isdir(dirname):
            try:
                # makedirs() also creates missing intermediate directories
                os.makedirs(dirname)
            except OSError:
                # tolerate the directory being created concurrently
                if not os.path.isdir(dirname):
                    raise
        connection = sqlite3.connect(
            filename, detect_types=sqlite3.PARSE_DECLTYPES,
            check_same_thread=False)
        try:
            connection.row_factory = sqlite3.Row
            # initialise connection properties
            connection.executescript('''
                -- store temporary tables in memory
                PRAGMA temp_store = MEMORY;
                -- disable sync() on database (corruption on disk failure)
                PRAGMA synchronous = OFF;
                -- put journal in memory (corruption if crash during transaction)
                PRAGMA journal_mode = MEMORY;
                ''')
        except sqlite3.Error:
            # do not leak a half-initialised connection
            connection.close()
            raise
        _connection = connection
    return _connection