249 lines
9.3 KiB
Python
249 lines
9.3 KiB
Python
|
# -*- coding: utf-8 -*-
|
||
|
import cherrypy
|
||
|
import logging
|
||
|
import threading
|
||
|
import psycopg2
|
||
|
import lybmods.lybcfg as lybcfg
|
||
|
import pickle as pickle
|
||
|
import lybmods.lybtools as lybtools
|
||
|
psycopg2.extensions.register_type(psycopg2.extensions.UNICODE)
|
||
|
psycopg2.extensions.register_type(psycopg2.extensions.UNICODEARRAY)
|
||
|
|
||
|
logger = logging.getLogger('Session')
|
||
|
|
||
|
class PgsqlSession(cherrypy.lib.sessions.Session):
|
||
|
sess_table_name = lybcfg.sess_table_name
|
||
|
sess_data_table_name = lybcfg.sess_data_table_name
|
||
|
connect_arguments = "dbname={0} user={1} password={2} host={3}".format(lybcfg.dbname, lybcfg.dbuser, lybcfg.dbpass, lybcfg.dbhost)
|
||
|
|
||
|
missing = False
|
||
|
"True if the session requested by the client did not exist."
|
||
|
|
||
|
_database = None
|
||
|
|
||
|
def __init__(self, sid=None, **kwargs):
|
||
|
logger.debug('Initializing PgsqlSession with %r' % kwargs)
|
||
|
for k, v in list(kwargs.items()):
|
||
|
setattr(PgsqlSession, k, v)
|
||
|
self.db = self.get_db()
|
||
|
self.cursor = self.db.cursor()
|
||
|
super(PgsqlSession, self).__init__(sid, **kwargs)
|
||
|
|
||
|
@classmethod
|
||
|
def get_db(cls):
|
||
|
##
|
||
|
## Use thread-local connections
|
||
|
local = threading.local()
|
||
|
if hasattr(local, 'db'):
|
||
|
return local.db
|
||
|
else:
|
||
|
logger.debug("Connecting to %r" % cls.connect_arguments)
|
||
|
db = psycopg2.connect(cls.connect_arguments)
|
||
|
local.db = db
|
||
|
|
||
|
return db
|
||
|
|
||
|
def load(self):
|
||
|
"""Copy stored session data into this session instance."""
|
||
|
exptime = self._load()
|
||
|
# data is either None or a tuple (session_data, expiration_time)
|
||
|
if exptime is None or exptime < self.now().utcnow():
|
||
|
if self.debug:
|
||
|
cherrypy.log('Expired session, flushing data', 'TOOLS.SESSIONS')
|
||
|
self.regenerate()
|
||
|
self.loaded = True
|
||
|
|
||
|
# Stick the clean_thread in the class, not the instance.
|
||
|
# The instances are created and destroyed per-request.
|
||
|
cls = self.__class__
|
||
|
if self.clean_freq and not cls.clean_thread:
|
||
|
# clean_up is in instancemethod and not a classmethod,
|
||
|
# so that tool config can be accessed inside the method.
|
||
|
t = cherrypy.process.plugins.Monitor(
|
||
|
cherrypy.engine, self.clean_up, self.clean_freq * 60,
|
||
|
name='Session cleanup')
|
||
|
t.subscribe()
|
||
|
cls.clean_thread = t
|
||
|
t.start()
|
||
|
|
||
|
def _load(self):
|
||
|
logger.debug('_load %r' % self)
|
||
|
# Select session data from table
|
||
|
self.cursor.execute('select expiration_time from %s '
|
||
|
'where id = %%s' % PgsqlSession.sess_table_name, (self.id,))
|
||
|
row = self.cursor.fetchone()
|
||
|
if row:
|
||
|
expiration_time = row[0].utcfromtimestamp(row[0].timestamp())
|
||
|
return expiration_time
|
||
|
else:
|
||
|
return None
|
||
|
|
||
|
def _save(self, expiration_time):
|
||
|
logger.debug('_save %r' % self)
|
||
|
self.cursor.execute('select count(*) from %s where id = %%s and expiration_time > now()' % PgsqlSession.sess_table_name, (self.id,))
|
||
|
(count,) = self.cursor.fetchone()
|
||
|
if count:
|
||
|
self.cursor.execute('update %s set expiration_time = %%s where id = %%s' % PgsqlSession.sess_table_name,
|
||
|
(expiration_time, self.id))
|
||
|
else:
|
||
|
self.cursor.execute('insert into %s (expiration_time, id) values (%%s, %%s)' % PgsqlSession.sess_table_name,
|
||
|
(expiration_time, self.id))
|
||
|
self.db.commit()
|
||
|
|
||
|
def acquire_lock(self):
|
||
|
logger.debug('acquire_lock %r' % self)
|
||
|
self.locked = True
|
||
|
self.cursor.execute('select id from %s where id = %%s for update' % PgsqlSession.sess_table_name,
|
||
|
(self.id,))
|
||
|
self.db.commit()
|
||
|
|
||
|
def release_lock(self):
|
||
|
logger.debug('release_lock %r' % self)
|
||
|
self.locked = False
|
||
|
self.db.commit()
|
||
|
|
||
|
def clean_up(self):
|
||
|
logger.debug('clean_up %r' % self)
|
||
|
|
||
|
self.cursor.execute('DELETE FROM %s WHERE expiration_time < now()' % PgsqlSession.sess_table_name)
|
||
|
self.db.commit()
|
||
|
|
||
|
def _delete(self):
|
||
|
logger.debug('_delete %r' % self)
|
||
|
self.cursor.execute('DELETE FROM %s WHERE id=%%s' % PgsqlSession.sess_table_name, (self.id,))
|
||
|
self.db.commit()
|
||
|
|
||
|
|
||
|
def _exists(self):
|
||
|
# Select session data from table
|
||
|
self.cursor.execute('select count(*) from %s '
|
||
|
'where id = %%s and expiration_time > now()' % PgsqlSession.sess_table_name, (self.id,))
|
||
|
(count,) = self.cursor.fetchone()
|
||
|
logger.debug('_exists %r (%r)' % (self, bool(count)))
|
||
|
return bool(count)
|
||
|
|
||
|
def __getitem__(self, key):
|
||
|
if not self.loaded: self.load()
|
||
|
self.cursor.execute('SELECT value FROM %s WHERE key=%%s AND sid=%%s;' % PgsqlSession.sess_data_table_name, (key, self.id))
|
||
|
try:
|
||
|
res = self.cursor.fetchall()[0][0]
|
||
|
val = pickle.loads(res)
|
||
|
return val
|
||
|
except:
|
||
|
raise KeyError(key)
|
||
|
|
||
|
def __setitem__(self, key, value):
|
||
|
if not self.loaded: self.load()
|
||
|
self.cursor.execute('SELECT count(*) FROM %s WHERE key = %%s AND sid = %%s;' % PgsqlSession.sess_data_table_name, (key, self.id))
|
||
|
count = self.cursor.fetchone()[0]
|
||
|
if count:
|
||
|
self.cursor.execute('UPDATE %s SET value = %%s WHERE key = %%s AND sid = %%s;' % PgsqlSession.sess_data_table_name,
|
||
|
(pickle.dumps(value), key, self.id))
|
||
|
else:
|
||
|
self.cursor.execute('INSERT INTO %s (value, key, sid) VALUES (%%s, %%s, %%s);' % PgsqlSession.sess_data_table_name,
|
||
|
(pickle.dumps(value), key, self.id))
|
||
|
self.db.commit()
|
||
|
def __delitem__(self, key):
|
||
|
if not self.loaded: self.load()
|
||
|
self.cursor.execute('DELETE FROM %s WHERE key = %%s AND sid = %%s;' % PgsqlSession.sess_data_table_name, (key, self.id))
|
||
|
self.db.commit()
|
||
|
def pop(self, key, default=None):
|
||
|
"""Remove the specified key and return the corresponding value.
|
||
|
If key is not found, default is returned if given,
|
||
|
otherwise KeyError is raised.
|
||
|
"""
|
||
|
if not self.loaded: self.load()
|
||
|
try:
|
||
|
val = self[key]
|
||
|
except:
|
||
|
if default is None:
|
||
|
raise KeyError(key)
|
||
|
else:
|
||
|
val = default
|
||
|
return val
|
||
|
del self[key]
|
||
|
return val
|
||
|
|
||
|
def __contains__(self, key):
|
||
|
if not self.loaded: self.load()
|
||
|
self.cursor.execute('SELECT count(*) FROM %s WHERE key = %%s AND sid = %%s;' % PgsqlSession.sess_data_table_name, (key, self.id))
|
||
|
return bool(self.cursor.fetchone()[0])
|
||
|
|
||
|
if hasattr({}, 'has_key'):
|
||
|
def has_key(self, key):
|
||
|
"""D.has_key(k) -> True if D has a key k, else False."""
|
||
|
if not self.loaded: self.load()
|
||
|
return self.__contains__(key)
|
||
|
|
||
|
def get(self, key, default=None):
|
||
|
"""D.get(k[,d]) -> D[k] if k in D, else d. d defaults to None."""
|
||
|
if not self.loaded: self.load()
|
||
|
try:
|
||
|
val = self[key]
|
||
|
return val
|
||
|
except:
|
||
|
return default
|
||
|
|
||
|
def update(self, d):
|
||
|
"""D.update(E) -> None. Update D from E: for k in E: D[k] = E[k]."""
|
||
|
if not self.loaded: self.load()
|
||
|
for k in d:
|
||
|
self[k] = d[k]
|
||
|
|
||
|
def setdefault(self, key, default=None):
|
||
|
"""D.setdefault(k[,d]) -> D.get(k,d), also set D[k]=d if k not in D."""
|
||
|
if not self.loaded: self.load()
|
||
|
try:
|
||
|
val = self[key]
|
||
|
except:
|
||
|
self[key] = val = default
|
||
|
return val
|
||
|
|
||
|
def clear(self):
|
||
|
"""D.clear() -> None. Remove all items from D."""
|
||
|
if not self.loaded: self.load()
|
||
|
self.cursor.execute('DELETE FROM %s WHERE sid=%%s AND key != %%s;' % PgsqlSession.sess_data_table_name, (self.id, lybtools.SESSION_KEY))
|
||
|
self.db.commit()
|
||
|
|
||
|
def keys(self):
|
||
|
"""D.keys() -> list of D's keys."""
|
||
|
if not self.loaded: self.load()
|
||
|
k = []
|
||
|
self.cursor.execute('SELECT key FROM %s WHERE sid = %%s;' % PgsqlSession.sess_data_table_name, (self.id,))
|
||
|
res = self.cursor.fetchall()
|
||
|
if res != []:
|
||
|
for i in res:
|
||
|
k.append(i[0])
|
||
|
return k
|
||
|
|
||
|
def items(self):
|
||
|
"""D.items() -> list of D's (key, value) pairs, as 2-tuples."""
|
||
|
if not self.loaded: self.load()
|
||
|
i = []
|
||
|
self.cursor.execute('SELECT key, value FROM %s WHERE sid = %%s;' % PgsqlSession.sess_data_table_name, (self.id,))
|
||
|
res = self.cursor.fetchall()
|
||
|
if res != []:
|
||
|
for k, v in res:
|
||
|
i.append({k, pickle.loads(v)})
|
||
|
return i
|
||
|
|
||
|
def values(self):
|
||
|
"""D.values() -> list of D's values."""
|
||
|
if not self.loaded: self.load()
|
||
|
v = []
|
||
|
self.cursor.execute('SELECT value FROM %s WHERE sid = %%s;' % PgsqlSession.sess_data_table_name, (self.id,))
|
||
|
res = self.cursor.fetchall()
|
||
|
if res != []:
|
||
|
for i in res:
|
||
|
v.append(pickle.loads(i[0]))
|
||
|
return v
|
||
|
|
||
|
def __del__(self):
|
||
|
logger.debug('__del__ %r' % self)
|
||
|
self.db.commit()
|
||
|
self.db.close()
|
||
|
self.db = None
|
||
|
|
||
|
def __repr__(self):
|
||
|
return '<PgsqlSession %r>' % (self.id,)
|