diff --git a/src/etc/htmldocck.py b/src/etc/htmldocck.py index 6bb235b2c83..df215f31823 100644 --- a/src/etc/htmldocck.py +++ b/src/etc/htmldocck.py @@ -285,6 +285,11 @@ def flatten(node): return ''.join(acc) +def make_xml(text): + xml = ET.XML('%s' % text) + return xml + + def normalize_xpath(path): path = path.replace("{{channel}}", channel) if path.startswith('//'): @@ -401,7 +406,7 @@ def get_tree_count(tree, path): return len(tree.findall(path)) -def check_snapshot(snapshot_name, tree, normalize_to_text): +def check_snapshot(snapshot_name, actual_tree, normalize_to_text): assert rust_test_path.endswith('.rs') snapshot_path = '{}.{}.{}'.format(rust_test_path[:-3], snapshot_name, 'html') try: @@ -414,11 +419,15 @@ def check_snapshot(snapshot_name, tree, normalize_to_text): raise FailedCheck('No saved snapshot value') if not normalize_to_text: - actual_str = ET.tostring(tree).decode('utf-8') + actual_str = ET.tostring(actual_tree).decode('utf-8') else: - actual_str = flatten(tree) + actual_str = flatten(actual_tree) + + if not expected_str \ + or (not normalize_to_text and + not compare_tree(make_xml(actual_str), make_xml(expected_str), stderr)) \ + or (normalize_to_text and actual_str != expected_str): - if expected_str != actual_str: if bless: with open(snapshot_path, 'w') as snapshot_file: snapshot_file.write(actual_str) @@ -430,6 +439,59 @@ def check_snapshot(snapshot_name, tree, normalize_to_text): print() raise FailedCheck('Actual snapshot value is different than expected') + +# Adapted from https://github.com/formencode/formencode/blob/3a1ba9de2fdd494dd945510a4568a3afeddb0b2e/formencode/doctest_xml_compare.py#L72-L120 +def compare_tree(x1, x2, reporter=None): + if x1.tag != x2.tag: + if reporter: + reporter('Tags do not match: %s and %s' % (x1.tag, x2.tag)) + return False + for name, value in x1.attrib.items(): + if x2.attrib.get(name) != value: + if reporter: + reporter('Attributes do not match: %s=%r, %s=%r' + % (name, value, name, x2.attrib.get(name))) + return False + for name in x2.attrib: + if name not in x1.attrib: + if reporter: + reporter('x2 has an attribute x1 is missing: %s' + % name) + return False + if not text_compare(x1.text, x2.text): + if reporter: + reporter('text: %r != %r' % (x1.text, x2.text)) + return False + if not text_compare(x1.tail, x2.tail): + if reporter: + reporter('tail: %r != %r' % (x1.tail, x2.tail)) + return False + cl1 = list(x1) + cl2 = list(x2) + if len(cl1) != len(cl2): + if reporter: + reporter('children length differs, %i != %i' + % (len(cl1), len(cl2))) + return False + i = 0 + for c1, c2 in zip(cl1, cl2): + i += 1 + if not compare_tree(c1, c2, reporter=reporter): + if reporter: + reporter('children %i do not match: %s' + % (i, c1.tag)) + return False + return True + + +def text_compare(t1, t2): + if not t1 and not t2: + return True + if t1 == '*' or t2 == '*': + return True + return (t1 or '').strip() == (t2 or '').strip() + + def stderr(*args): if sys.version_info.major < 3: file = codecs.getwriter('utf-8')(sys.stderr)