# -*- coding: utf-8 -*-
"""
Tests for thread usage in lxml.etree.
"""
import re
import sys
import os.path
import unittest
import threading
this_dir = os.path.dirname(__file__)
if this_dir not in sys.path:
sys.path.insert(0, this_dir) # needed for Py3
from common_imports import etree, HelperTestCase, BytesIO, _bytes
try:
from Queue import Queue
except ImportError:
from queue import Queue # Py3
class ThreadingTestCase(HelperTestCase):
"""Threading tests"""
etree = etree
def _run_thread(self, func):
thread = threading.Thread(target=func)
thread.start()
thread.join()
def _run_threads(self, count, func, main_func=None):
sync = threading.Event()
lock = threading.Lock()
counter = dict(started=0, finished=0, failed=0)
def sync_start(func):
with lock:
started = counter['started'] + 1
counter['started'] = started
if started < count + (main_func is not None):
sync.wait(4) # wait until the other threads have started up
assert sync.is_set()
sync.set() # all waiting => go!
try:
func()
except:
with lock:
counter['failed'] += 1
raise
else:
with lock:
counter['finished'] += 1
threads = [threading.Thread(target=sync_start, args=(func,)) for _ in range(count)]
for thread in threads:
thread.start()
if main_func is not None:
sync_start(main_func)
for thread in threads:
thread.join()
self.assertEqual(0, counter['failed'])
self.assertEqual(counter['finished'], counter['started'])
def test_subtree_copy_thread(self):
tostring = self.etree.tostring
XML = self.etree.XML
xml = _bytes("<root><threadtag/></root>")
main_root = XML(_bytes("<root/>"))
def run_thread():
thread_root = XML(xml)
main_root.append(thread_root[0])
del thread_root
self._run_thread(run_thread)
self.assertEqual(xml, tostring(main_root))
def test_main_xslt_in_thread(self):
XML = self.etree.XML
style = XML(_bytes('''\
<xsl:stylesheet version="1.0"
xmlns:xsl="http://www.w3.org/1999/XSL/Transform">
<xsl:template match="*">
<foo><xsl:copy><xsl:value-of select="/a/b/text()" /></xsl:copy></foo>
</xsl:template>
</xsl:stylesheet>'''))
st = etree.XSLT(style)
result = []
def run_thread():
root = XML(_bytes('<a><b>B</b><c>C</c></a>'))
result.append( st(root) )
self._run_thread(run_thread)
self.assertEqual('''\
<?xml version="1.0"?>
<foo><a>B</a></foo>
''',
str(result[0]))
def test_thread_xslt(self):
XML = self.etree.XML
tostring = self.etree.tostring
root = XML(_bytes('<a><b>B</b><c>C</c></a>'))
def run_thread():
style = XML(_bytes('''\
<xsl:stylesheet version="1.0"
xmlns:xsl="http://www.w3.org/1999/XSL/Transform">
<xsl:template match="*">
<foo><xsl:copy><xsl:value-of select="/a/b/text()" /></xsl:copy></foo>
</xsl:template>
</xsl:stylesheet>'''))
st = etree.XSLT(style)
root.append( st(root).getroot() )
self._run_thread(run_thread)
self.assertEqual(_bytes('<a><b>B</b><c>C</c><foo><a>B</a></foo></a>'),
tostring(root))
def test_thread_xslt_parsing_error_log(self):
style = self.parse('''\
<xsl:stylesheet version="1.0"
xmlns:xsl="http://www.w3.org/1999/XSL/Transform">
<xsl:template match="tag" />
<!-- extend time for parsing + transform -->
''' + '\n'.join('<xsl:template match="tag%x" />' % i for i in range(200)) + '''
<xsl:foo />
</xsl:stylesheet>''')
self.assertRaises(etree.XSLTParseError,
etree.XSLT, style)
error_logs = []
def run_thread():
try:
etree.XSLT(style)
except etree.XSLTParseError as e:
error_logs.append(e.error_log)
else:
self.assertFalse(True, "XSLT parsing should have failed but didn't")
self._run_threads(16, run_thread)
self.assertEqual(16, len(error_logs))
last_log = None
for log in error_logs:
self.assertTrue(len(log))
if last_log is not None:
self.assertEqual(len(last_log), len(log))
self.assertEqual(4, len(log))
for error in log:
self.assertTrue(':ERROR:XSLT:' in str(error))
last_log = log
def test_thread_xslt_apply_error_log(self):
tree = self.parse('<tagFF/>')
style = self.parse('''\
<xsl:stylesheet version="1.0"
xmlns:xsl="http://www.w3.org/1999/XSL/Transform">
<xsl:template name="tag0">
<xsl:message terminate="yes">FAIL</xsl:message>
</xsl:template>
<!-- extend time for parsing + transform -->
''' + '\n'.join('<xsl:template match="tag%X" name="tag%x"> <xsl:call-template name="tag%x" /> </xsl:template>' % (i, i, i-1)
for i in range(1, 256)) + '''
</xsl:stylesheet>''')
self.assertRaises(etree.XSLTApplyError,
etree.XSLT(style), tree)
error_logs = []
def run_thread():
transform = etree.XSLT(style)
try:
transform(tree)
except etree.XSLTApplyError:
error_logs.append(transform.error_log)
else:
self.assertFalse(True, "XSLT parsing should have failed but didn't")
self._run_threads(16, run_thread)
self.assertEqual(16, len(error_logs))
last_log = None
for log in error_logs:
self.assertTrue(len(log))
if last_log is not None:
self.assertEqual(len(last_log), len(log))
self.assertEqual(1, len(log))
for error in log:
self.assertTrue(':ERROR:XSLT:' in str(error))
last_log = log
def test_thread_xslt_attr_replace(self):
# this is the only case in XSLT where the result tree can be
# modified in-place
XML = self.etree.XML
tostring = self.etree.tostring
style = self.etree.XSLT(XML(_bytes('''\
<xsl:stylesheet version="1.0"
xmlns:xsl="http://www.w3.org/1999/XSL/Transform">
<xsl:template match="*">
<root class="abc">
<xsl:copy-of select="@class" />
<xsl:attribute name="class">xyz</xsl:attribute>
</root>
</xsl:template>
</xsl:stylesheet>''')))
result = []
def run_thread():
root = XML(_bytes('<ROOT class="ABC" />'))
result.append( style(root).getroot() )
self._run_thread(run_thread)
self.assertEqual(_bytes('<root class="xyz"/>'),
tostring(result[0]))
def test_thread_create_xslt(self):
XML = self.etree.XML
tostring = self.etree.tostring
root = XML(_bytes('<a><b>B</b><c>C</c></a>'))
stylesheets = []
def run_thread():
style = XML(_bytes('''\
<xsl:stylesheet
xmlns:xsl="http://www.w3.org/1999/XSL/Transform"
version="1.0">
<xsl:output method="xml" />
<xsl:template match="/">
<div id="test">
<xsl:apply-templates/>
</div>
</xsl:template>
</xsl:stylesheet>'''))
stylesheets.append( etree.XSLT(style) )
self._run_thread(run_thread)
st = stylesheets[0]
result = tostring( st(root) )
self.assertEqual(_bytes('<div id="test">BC</div>'),
result)
def test_thread_error_log(self):
XML = self.etree.XML
expected_error = [self.etree.ErrorTypes.ERR_TAG_NAME_MISMATCH]
children = "<a>test</a>" * 100
def parse_error_test(thread_no):
tag = "tag%d" % thread_no
xml = "<%s>%s</%s>" % (tag, children, tag.upper())
parser = self.etree.XMLParser()
for _ in range(10):
errors = None
try:
XML(xml, parser)
except self.etree.ParseError:
e = sys.exc_info()[1]
errors = e.error_log.filter_types(expected_error)
self.assertTrue(errors, "Expected error not found")
for error in errors:
self.assertTrue(
tag in error.message and tag.upper() in error.message,
"%s and %s not found in '%s'" % (
tag, tag.upper(), error.message))
self.etree.clear_error_log()
threads = []
for thread_no in range(1, 10):
t = threading.Thread(target=parse_error_test,
args=(thread_no,))
threads.append(t)
t.start()
parse_error_test(0)
for t in threads:
t.join()
def test_thread_mix(self):
XML = self.etree.XML
Element = self.etree.Element
SubElement = self.etree.SubElement
tostring = self.etree.tostring
xml = _bytes('<a><b>B</b><c xmlns="test">C</c></a>')
root = XML(xml)
fragment = XML(_bytes("<other><tags/></other>"))
result = self.etree.Element("{myns}root", att = "someval")
def run_XML():
thread_root = XML(xml)
result.append(thread_root[0])
result.append(thread_root[-1])
def run_parse():
thread_root = self.etree.parse(BytesIO(xml)).getroot()
result.append(thread_root[0])
result.append(thread_root[-1])
def run_move_main():
result.append(fragment[0])
def run_build():
result.append(
Element("{myns}foo", attrib={'{test}attr':'val'}))
SubElement(result, "{otherns}tasty")
def run_xslt():
style = XML(_bytes('''\
<xsl:stylesheet version="1.0"
xmlns:xsl="http://www.w3.org/1999/XSL/Transform">
<xsl:template match="*">
<xsl:copy><foo><xsl:value-of select="/a/b/text()" /></foo></xsl:copy>
</xsl:template>
</xsl:stylesheet>'''))
st = etree.XSLT(style)
result.append( st(root).getroot() )
for test in (run_XML, run_parse, run_move_main, run_xslt, run_build):
tostring(result)
self._run_thread(test)
self.assertEqual(
_bytes('<ns0:root xmlns:ns0="myns" att="someval"><b>B</b>'
'<c xmlns="test">C</c><b>B</b><c xmlns="test">C</c><tags/>'
'<a><foo>B</foo></a>'
'<ns0:foo xmlns:ns1="test" ns1:attr="val"/>'
'<ns1:tasty xmlns:ns1="otherns"/></ns0:root>'),
tostring(result))
def strip_first():
root = Element("newroot")
root.append(result[0])
while len(result):
self._run_thread(strip_first)
self.assertEqual(
_bytes('<ns0:root xmlns:ns0="myns" att="someval"/>'),
tostring(result))
def test_concurrent_attribute_names_in_dicts(self):
SubElement = self.etree.SubElement
names = list('abcdefghijklmnop')
runs_per_name = range(50)
result_matches = re.compile(
br'<thread_root>'
br'(?:<[a-p]{5} thread_attr_[a-p]="value" thread_attr2_[a-p]="value2"\s?/>)+'
br'</thread_root>').match
def testrun():
for _ in range(3):
root = self.etree.Element('thread_root')
for name in names:
tag_name = name * 5
new = []
for _ in runs_per_name:
el = SubElement(root, tag_name, {'thread_attr_' + name: 'value'})
new.append(el)
for el in new:
el.set('thread_attr2_' + name, 'value2')
s = etree.tostring(root)
self.assertTrue(result_matches(s))
# first, run only in sub-threads
self._run_threads(10, testrun)
# then, additionally include the main thread (and its parent dict)
self._run_threads(10, testrun, main_func=testrun)
def test_concurrent_proxies(self):
XML = self.etree.XML
root = XML(_bytes('<root><a>A</a><b xmlns="test">B</b><c/></root>'))
child_count = len(root)
def testrun():
for i in range(10000):
el = root[i%child_count]
del el
self._run_threads(10, testrun)
def test_concurrent_class_lookup(self):
XML = self.etree.XML
class TestElement(etree.ElementBase):
pass
class MyLookup(etree.CustomElementClassLookup):
repeat = range(100)
def lookup(self, t, d, ns, name):
count = 0
for i in self.repeat:
# allow other threads to run
count += 1
return TestElement
parser = self.etree.XMLParser()
parser.set_element_class_lookup(MyLookup())
root = XML(_bytes('<root><a>A</a><b xmlns="test">B</b><c/></root>'),
parser)
child_count = len(root)
def testrun():
for i in range(1000):
el = root[i%child_count]
del el
self._run_threads(10, testrun)
class ThreadPipelineTestCase(HelperTestCase):
"""Threading tests based on a thread worker pipeline.
"""
etree = etree
item_count = 40
class Worker(threading.Thread):
def __init__(self, in_queue, in_count, **kwargs):
threading.Thread.__init__(self)
self.in_queue = in_queue
self.in_count = in_count
self.out_queue = Queue(in_count)
self.__dict__.update(kwargs)
def run(self):
get, put = self.in_queue.get, self.out_queue.put
handle = self.handle
for _ in range(self.in_count):
put(handle(get()))
def handle(self, data):
raise NotImplementedError()
class ParseWorker(Worker):
def handle(self, xml, _fromstring=etree.fromstring):
return _fromstring(xml)
class RotateWorker(Worker):
def handle(self, element):
first = element[0]
element[:] = element[1:]
element.append(first)
return element
class ReverseWorker(Worker):
def handle(self, element):
element[:] = element[::-1]
return element
class ParseAndExtendWorker(Worker):
def handle(self, element, _fromstring=etree.fromstring):
element.extend(_fromstring(self.xml))
return element
class ParseAndInjectWorker(Worker):
def handle(self, element, _fromstring=etree.fromstring):
root = _fromstring(self.xml)
root.extend(element)
return root
class Validate(Worker):
def handle(self, element):
element.getroottree().docinfo.internalDTD.assertValid(element)
return element
class SerialiseWorker(Worker):
def handle(self, element):
return etree.tostring(element)
xml = (b'''\
<!DOCTYPE threadtest [
<!ELEMENT threadtest (thread-tag1,thread-tag2)+>
<!ATTLIST threadtest
version CDATA "1.0"
>
<!ELEMENT thread-tag1 EMPTY>
<!ELEMENT thread-tag2 (div)>
<!ELEMENT div (threaded)>
<!ATTLIST div
huhu CDATA #IMPLIED
>
<!ELEMENT threaded EMPTY>
<!ATTLIST threaded
host CDATA #REQUIRED
>
]>
<threadtest version="123">
''' + (b'''
<thread-tag1 />
<thread-tag2>
<div huhu="true">
<threaded host="here" />
</div>
</thread-tag2>
''') * 20 + b'''
</threadtest>''')
def _build_pipeline(self, item_count, *classes, **kwargs):
in_queue = Queue(item_count)
start = last = classes[0](in_queue, item_count, **kwargs)
start.setDaemon(True)
for worker_class in classes[1:]:
last = worker_class(last.out_queue, item_count, **kwargs)
last.setDaemon(True)
last.start()
return (in_queue, start, last)
def test_thread_pipeline_thread_parse(self):
item_count = self.item_count
xml = self.xml.replace(b'thread', b'THREAD') # use fresh tag names
# build and start the pipeline
in_queue, start, last = self._build_pipeline(
item_count,
self.ParseWorker,
self.RotateWorker,
self.ReverseWorker,
self.ParseAndExtendWorker,
self.Validate,
self.ParseAndInjectWorker,
self.SerialiseWorker,
xml=xml)
# fill the queue
put = start.in_queue.put
for _ in range(item_count):
put(xml)
# start the first thread and thus everything
start.start()
# make sure the last thread has terminated
last.join(60) # time out after 60 seconds
self.assertEqual(item_count, last.out_queue.qsize())
# read the results
get = last.out_queue.get
results = [get() for _ in range(item_count)]
comparison = results[0]
for i, result in enumerate(results[1:]):
self.assertEqual(comparison, result)
def test_thread_pipeline_global_parse(self):
item_count = self.item_count
xml = self.xml.replace(b'thread', b'GLOBAL') # use fresh tag names
XML = self.etree.XML
# build and start the pipeline
in_queue, start, last = self._build_pipeline(
item_count,
self.RotateWorker,
self.ReverseWorker,
self.ParseAndExtendWorker,
self.Validate,
self.SerialiseWorker,
xml=xml)
# fill the queue
put = start.in_queue.put
for _ in range(item_count):
put(XML(xml))
# start the first thread and thus everything
start.start()
# make sure the last thread has terminated
last.join(60) # time out after 90 seconds
self.assertEqual(item_count, last.out_queue.qsize())
# read the results
get = last.out_queue.get
results = [get() for _ in range(item_count)]
comparison = results[0]
for i, result in enumerate(results[1:]):
self.assertEqual(comparison, result)
def test_suite():
suite = unittest.TestSuite()
suite.addTests([unittest.makeSuite(ThreadingTestCase)])
suite.addTests([unittest.makeSuite(ThreadPipelineTestCase)])
return suite
if __name__ == '__main__':
print('to test use test.py %s' % __file__)