summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--b4/__init__.py22
1 files changed, 13 insertions, 9 deletions
diff --git a/b4/__init__.py b/b4/__init__.py
index 483173c..d69c301 100644
--- a/b4/__init__.py
+++ b/b4/__init__.py
@@ -2110,18 +2110,20 @@ def get_msgid(cmdargs) -> Optional[str]:
def get_strict_thread(msgs, msgid):
want = {msgid}
- got = set()
seen = set()
maybe = dict()
- strict = list()
+ strict = dict()
while True:
for msg in msgs:
c_msgid = LoreMessage.get_clean_msgid(msg)
seen.add(c_msgid)
- if c_msgid in got:
+ if c_msgid in strict.keys():
+ # Check if the duplicate message passes DKIM validation
+ if not strict[c_msgid][0] and can_dkim and dkim.verify(msg.as_bytes(), logger=dkimlogger):
+ logger.debug('DKIM-validating message used for %s', msgid)
+ strict[c_msgid] = (True, msg)
continue
logger.debug('Looking at: %s', c_msgid)
-
refs = set()
msgrefs = list()
if msg.get('In-Reply-To', None):
@@ -2129,7 +2131,7 @@ def get_strict_thread(msgs, msgid):
if msg.get('References', None):
msgrefs += email.utils.getaddresses([str(x) for x in msg.get_all('references', [])])
for ref in set([x[1] for x in msgrefs]):
- if ref in got or ref in want:
+ if ref in strict.keys() or ref in want:
want.add(c_msgid)
elif len(ref):
refs.add(ref)
@@ -2140,8 +2142,10 @@ def get_strict_thread(msgs, msgid):
maybe[ref].add(c_msgid)
if c_msgid in want:
- strict.append(msg)
- got.add(c_msgid)
+ dkimres = None
+ if can_dkim:
+ dkimres = dkim.verify(msg.as_bytes(), logger=dkimlogger)
+ strict[c_msgid] = (dkimres, msg)
want.update(refs)
want.discard(c_msgid)
logger.debug('Kept in thread: %s', c_msgid)
@@ -2157,7 +2161,7 @@ def get_strict_thread(msgs, msgid):
# Remove any entries not in "seen" (missing messages)
for c_msgid in set(want):
- if c_msgid not in seen or c_msgid in got:
+ if c_msgid not in seen or c_msgid in strict.keys():
want.remove(c_msgid)
if not len(want):
break
@@ -2168,7 +2172,7 @@ def get_strict_thread(msgs, msgid):
if len(msgs) > len(strict):
logger.debug('Reduced mbox to strict matches only (%s->%s)', len(msgs), len(strict))
- return strict
+ return [x[1] for x in strict.values()]
def mailsplit_bytes(bmbox: bytes, outdir: str) -> list: