"""TestCases for multi-threaded access to a DB. """ import os import sys import time import errno from random import random DASH = '-' try: WindowsError except NameError: class WindowsError(Exception): pass import unittest from test_all import db, dbutils, test_support, verbose, have_threads, \ get_new_environment_path, get_new_database_path if have_threads : from threading import Thread if sys.version_info[0] < 3 : from threading import currentThread else : from threading import current_thread as currentThread #---------------------------------------------------------------------- class BaseThreadedTestCase(unittest.TestCase): dbtype = db.DB_UNKNOWN # must be set in derived class dbopenflags = 0 dbsetflags = 0 envflags = 0 def setUp(self): if verbose: dbutils._deadlock_VerboseFile = sys.stdout self.homeDir = get_new_environment_path() self.env = db.DBEnv() self.setEnvOpts() self.env.open(self.homeDir, self.envflags | db.DB_CREATE) self.filename = self.__class__.__name__ + '.db' self.d = db.DB(self.env) if self.dbsetflags: self.d.set_flags(self.dbsetflags) self.d.open(self.filename, self.dbtype, self.dbopenflags|db.DB_CREATE) def tearDown(self): self.d.close() self.env.close() test_support.rmtree(self.homeDir) def setEnvOpts(self): pass def makeData(self, key): return DASH.join([key] * 5) #---------------------------------------------------------------------- class ConcurrentDataStoreBase(BaseThreadedTestCase): dbopenflags = db.DB_THREAD envflags = db.DB_THREAD | db.DB_INIT_CDB | db.DB_INIT_MPOOL readers = 0 # derived class should set writers = 0 records = 1000 def test01_1WriterMultiReaders(self): if verbose: print '\n', '-=' * 30 print "Running %s.test01_1WriterMultiReaders..." % \ self.__class__.__name__ keys=range(self.records) import random random.shuffle(keys) records_per_writer=self.records//self.writers readers_per_writer=self.readers//self.writers self.assertEqual(self.records,self.writers*records_per_writer) self.assertEqual(self.readers,self.writers*readers_per_writer) self.assertTrue((records_per_writer%readers_per_writer)==0) readers = [] for x in xrange(self.readers): rt = Thread(target = self.readerThread, args = (self.d, x), name = 'reader %d' % x, )#verbose = verbose) if sys.version_info[0] < 3 : rt.setDaemon(True) else : rt.daemon = True readers.append(rt) writers=[] for x in xrange(self.writers): a=keys[records_per_writer*x:records_per_writer*(x+1)] a.sort() # Generate conflicts b=readers[readers_per_writer*x:readers_per_writer*(x+1)] wt = Thread(target = self.writerThread, args = (self.d, a, b), name = 'writer %d' % x, )#verbose = verbose) writers.append(wt) for t in writers: if sys.version_info[0] < 3 : t.setDaemon(True) else : t.daemon = True t.start() for t in writers: t.join() for t in readers: t.join() def writerThread(self, d, keys, readers): if sys.version_info[0] < 3 : name = currentThread().getName() else : name = currentThread().name if verbose: print "%s: creating records %d - %d" % (name, start, stop) count=len(keys)//len(readers) count2=count for x in keys : key = '%04d' % x dbutils.DeadlockWrap(d.put, key, self.makeData(key), max_retries=12) if verbose and x % 100 == 0: print "%s: records %d - %d finished" % (name, start, x) count2-=1 if not count2 : readers.pop().start() count2=count if verbose: print "%s: finished creating records" % name if verbose: print "%s: thread finished" % name def readerThread(self, d, readerNum): if sys.version_info[0] < 3 : name = currentThread().getName() else : name = currentThread().name for i in xrange(5) : c = d.cursor() count = 0 rec = c.first() while rec: count += 1 key, data = rec self.assertEqual(self.makeData(key), data) rec = c.next() if verbose: print "%s: found %d records" % (name, count) c.close() if verbose: print "%s: thread finished" % name class BTreeConcurrentDataStore(ConcurrentDataStoreBase): dbtype = db.DB_BTREE writers = 2 readers = 10 records = 1000 class HashConcurrentDataStore(ConcurrentDataStoreBase): dbtype = db.DB_HASH writers = 2 readers = 10 records = 1000 #---------------------------------------------------------------------- class SimpleThreadedBase(BaseThreadedTestCase): dbopenflags = db.DB_THREAD envflags = db.DB_THREAD | db.DB_INIT_MPOOL | db.DB_INIT_LOCK readers = 10 writers = 2 records = 1000 def setEnvOpts(self): self.env.set_lk_detect(db.DB_LOCK_DEFAULT) def test02_SimpleLocks(self): if verbose: print '\n', '-=' * 30 print "Running %s.test02_SimpleLocks..." % self.__class__.__name__ keys=range(self.records) import random random.shuffle(keys) records_per_writer=self.records//self.writers readers_per_writer=self.readers//self.writers self.assertEqual(self.records,self.writers*records_per_writer) self.assertEqual(self.readers,self.writers*readers_per_writer) self.assertTrue((records_per_writer%readers_per_writer)==0) readers = [] for x in xrange(self.readers): rt = Thread(target = self.readerThread, args = (self.d, x), name = 'reader %d' % x, )#verbose = verbose) if sys.version_info[0] < 3 : rt.setDaemon(True) else : rt.daemon = True readers.append(rt) writers = [] for x in xrange(self.writers): a=keys[records_per_writer*x:records_per_writer*(x+1)] a.sort() # Generate conflicts b=readers[readers_per_writer*x:readers_per_writer*(x+1)] wt = Thread(target = self.writerThread, args = (self.d, a, b), name = 'writer %d' % x, )#verbose = verbose) writers.append(wt) for t in writers: if sys.version_info[0] < 3 : t.setDaemon(True) else : t.daemon = True t.start() for t in writers: t.join() for t in readers: t.join() def writerThread(self, d, keys, readers): if sys.version_info[0] < 3 : name = currentThread().getName() else : name = currentThread().name if verbose: print "%s: creating records %d - %d" % (name, start, stop) count=len(keys)//len(readers) count2=count for x in keys : key = '%04d' % x dbutils.DeadlockWrap(d.put, key, self.makeData(key), max_retries=12) if verbose and x % 100 == 0: print "%s: records %d - %d finished" % (name, start, x) count2-=1 if not count2 : readers.pop().start() count2=count if verbose: print "%s: thread finished" % name def readerThread(self, d, readerNum): if sys.version_info[0] < 3 : name = currentThread().getName() else : name = currentThread().name c = d.cursor() count = 0 rec = dbutils.DeadlockWrap(c.first, max_retries=10) while rec: count += 1 key, data = rec self.assertEqual(self.makeData(key), data) rec = dbutils.DeadlockWrap(c.next, max_retries=10) if verbose: print "%s: found %d records" % (name, count) c.close() if verbose: print "%s: thread finished" % name class BTreeSimpleThreaded(SimpleThreadedBase): dbtype = db.DB_BTREE class HashSimpleThreaded(SimpleThreadedBase): dbtype = db.DB_HASH #---------------------------------------------------------------------- class ThreadedTransactionsBase(BaseThreadedTestCase): dbopenflags = db.DB_THREAD | db.DB_AUTO_COMMIT envflags = (db.DB_THREAD | db.DB_INIT_MPOOL | db.DB_INIT_LOCK | db.DB_INIT_LOG | db.DB_INIT_TXN ) readers = 0 writers = 0 records = 2000 txnFlag = 0 def setEnvOpts(self): #self.env.set_lk_detect(db.DB_LOCK_DEFAULT) pass def test03_ThreadedTransactions(self): if verbose: print '\n', '-=' * 30 print "Running %s.test03_ThreadedTransactions..." % \ self.__class__.__name__ keys=range(self.records) import random random.shuffle(keys) records_per_writer=self.records//self.writers readers_per_writer=self.readers//self.writers self.assertEqual(self.records,self.writers*records_per_writer) self.assertEqual(self.readers,self.writers*readers_per_writer) self.assertTrue((records_per_writer%readers_per_writer)==0) readers=[] for x in xrange(self.readers): rt = Thread(target = self.readerThread, args = (self.d, x), name = 'reader %d' % x, )#verbose = verbose) if sys.version_info[0] < 3 : rt.setDaemon(True) else : rt.daemon = True readers.append(rt) writers = [] for x in xrange(self.writers): a=keys[records_per_writer*x:records_per_writer*(x+1)] b=readers[readers_per_writer*x:readers_per_writer*(x+1)] wt = Thread(target = self.writerThread, args = (self.d, a, b), name = 'writer %d' % x, )#verbose = verbose) writers.append(wt) dt = Thread(target = self.deadlockThread) if sys.version_info[0] < 3 : dt.setDaemon(True) else : dt.daemon = True dt.start() for t in writers: if sys.version_info[0] < 3 : t.setDaemon(True) else : t.daemon = True t.start() for t in writers: t.join() for t in readers: t.join() self.doLockDetect = False dt.join() def writerThread(self, d, keys, readers): if sys.version_info[0] < 3 : name = currentThread().getName() else : name = currentThread().name count=len(keys)//len(readers) while len(keys): try: txn = self.env.txn_begin(None, self.txnFlag) keys2=keys[:count] for x in keys2 : key = '%04d' % x d.put(key, self.makeData(key), txn) if verbose and x % 100 == 0: print "%s: records %d - %d finished" % (name, start, x) txn.commit() keys=keys[count:] readers.pop().start() except (db.DBLockDeadlockError, db.DBLockNotGrantedError), val: if verbose: if sys.version_info < (2, 6) : print "%s: Aborting transaction (%s)" % (name, val[1]) else : print "%s: Aborting transaction (%s)" % (name, val.args[1]) txn.abort() if verbose: print "%s: thread finished" % name def readerThread(self, d, readerNum): if sys.version_info[0] < 3 : name = currentThread().getName() else : name = currentThread().name finished = False while not finished: try: txn = self.env.txn_begin(None, self.txnFlag) c = d.cursor(txn) count = 0 rec = c.first() while rec: count += 1 key, data = rec self.assertEqual(self.makeData(key), data) rec = c.next() if verbose: print "%s: found %d records" % (name, count) c.close() txn.commit() finished = True except (db.DBLockDeadlockError, db.DBLockNotGrantedError), val: if verbose: if sys.version_info < (2, 6) : print "%s: Aborting transaction (%s)" % (name, val[1]) else : print "%s: Aborting transaction (%s)" % (name, val.args[1]) c.close() txn.abort() if verbose: print "%s: thread finished" % name def deadlockThread(self): self.doLockDetect = True while self.doLockDetect: time.sleep(0.05) try: aborted = self.env.lock_detect( db.DB_LOCK_RANDOM, db.DB_LOCK_CONFLICT) if verbose and aborted: print "deadlock: Aborted %d deadlocked transaction(s)" \ % aborted except db.DBError: pass class BTreeThreadedTransactions(ThreadedTransactionsBase): dbtype = db.DB_BTREE writers = 2 readers = 10 records = 1000 class HashThreadedTransactions(ThreadedTransactionsBase): dbtype = db.DB_HASH writers = 2 readers = 10 records = 1000 class BTreeThreadedNoWaitTransactions(ThreadedTransactionsBase): dbtype = db.DB_BTREE writers = 2 readers = 10 records = 1000 txnFlag = db.DB_TXN_NOWAIT class HashThreadedNoWaitTransactions(ThreadedTransactionsBase): dbtype = db.DB_HASH writers = 2 readers = 10 records = 1000 txnFlag = db.DB_TXN_NOWAIT #---------------------------------------------------------------------- def test_suite(): suite = unittest.TestSuite() if have_threads: suite.addTest(unittest.makeSuite(BTreeConcurrentDataStore)) suite.addTest(unittest.makeSuite(HashConcurrentDataStore)) suite.addTest(unittest.makeSuite(BTreeSimpleThreaded)) suite.addTest(unittest.makeSuite(HashSimpleThreaded)) suite.addTest(unittest.makeSuite(BTreeThreadedTransactions)) suite.addTest(unittest.makeSuite(HashThreadedTransactions)) suite.addTest(unittest.makeSuite(BTreeThreadedNoWaitTransactions)) suite.addTest(unittest.makeSuite(HashThreadedNoWaitTransactions)) else: print "Threads not available, skipping thread tests." return suite if __name__ == '__main__': unittest.main(defaultTest='test_suite')