From 3e96a43accac28586e34a1a2b4a3c90948281bd0 Mon Sep 17 00:00:00 2001 From: Konstantin Ryabitsev Date: Tue, 17 Aug 2021 16:55:01 -0400 Subject: Move dedupe code into central location We want to dedupe all threads we retrieve from public-inbox, so do this in the central place instead of only when doing get_strict_thread(). Signed-off-by: Konstantin Ryabitsev --- b4/__init__.py | 64 +++++++++++++++++++++++++++++++++------------------------- 1 file changed, 36 insertions(+), 28 deletions(-) diff --git a/b4/__init__.py b/b4/__init__.py index 7e102ea..fd9979a 100644 --- a/b4/__init__.py +++ b/b4/__init__.py @@ -2145,18 +2145,18 @@ def get_msgid(cmdargs) -> Optional[str]: def get_strict_thread(msgs, msgid): want = {msgid} + got = set() seen = set() maybe = dict() - strict = dict() + strict = list() while True: for msg in msgs: c_msgid = LoreMessage.get_clean_msgid(msg) seen.add(c_msgid) - if c_msgid in strict.keys(): - logger.debug('Picked a more preferred source for %s', msgid) - strict[c_msgid] = LoreMessage.get_preferred_duplicate(strict[c_msgid], msg) + if c_msgid in got: continue logger.debug('Looking at: %s', c_msgid) + refs = set() msgrefs = list() if msg.get('In-Reply-To', None): @@ -2164,7 +2164,7 @@ def get_strict_thread(msgs, msgid): if msg.get('References', None): msgrefs += email.utils.getaddresses([str(x) for x in msg.get_all('references', [])]) for ref in set([x[1] for x in msgrefs]): - if ref in strict.keys() or ref in want: + if ref in got or ref in want: want.add(c_msgid) elif len(ref): refs.add(ref) @@ -2175,7 +2175,8 @@ def get_strict_thread(msgs, msgid): maybe[ref].add(c_msgid) if c_msgid in want: - strict[c_msgid] = msg + strict.append(msg) + got.add(c_msgid) want.update(refs) want.discard(c_msgid) logger.debug('Kept in thread: %s', c_msgid) @@ -2191,7 +2192,7 @@ def get_strict_thread(msgs, msgid): # Remove any entries not in "seen" (missing messages) for c_msgid in set(want): - if c_msgid not in 
seen or c_msgid in got: want.remove(c_msgid) if not len(want): break @@ -2202,7 +2203,7 @@ def get_strict_thread(msgs, msgid): if len(msgs) > len(strict): logger.debug('Reduced mbox to strict matches only (%s->%s)', len(msgs), len(strict)) - return strict.values() + return strict def mailsplit_bytes(bmbox: bytes, outdir: str) -> list: @@ -2228,26 +2229,33 @@ def get_pi_thread_by_url(t_mbx_url, nocache=False): for msg in os.listdir(cachedir): with open(os.path.join(cachedir, msg), 'rb') as fh: msgs.append(email.message_from_binary_file(fh)) - return msgs - - logger.critical('Grabbing thread from %s', t_mbx_url.split('://')[1]) - session = get_requests_session() - resp = session.get(t_mbx_url) - if resp.status_code != 200: - logger.critical('Server returned an error: %s', resp.status_code) - return None - t_mbox = gzip.decompress(resp.content) - resp.close() - if not len(t_mbox): - logger.critical('No messages found for that query') - return None - # Convert into individual files using git-mailsplit - with tempfile.TemporaryDirectory(suffix='-mailsplit') as tfd: - msgs = mailsplit_bytes(t_mbox, tfd) - if os.path.exists(cachedir): - shutil.rmtree(cachedir) - shutil.copytree(tfd, cachedir) - return msgs + else: + logger.critical('Grabbing thread from %s', t_mbx_url.split('://')[1]) + session = get_requests_session() + resp = session.get(t_mbx_url) + if resp.status_code != 200: + logger.critical('Server returned an error: %s', resp.status_code) + return None + t_mbox = gzip.decompress(resp.content) + resp.close() + if not len(t_mbox): + logger.critical('No messages found for that query') + return None + # Convert into individual files using git-mailsplit + with tempfile.TemporaryDirectory(suffix='-mailsplit') as tfd: + msgs = mailsplit_bytes(t_mbox, tfd) + if os.path.exists(cachedir): + shutil.rmtree(cachedir) + shutil.copytree(tfd, cachedir) + + deduped = dict() + for msg in msgs: + msgid = LoreMessage.get_clean_msgid(msg) + if msgid in deduped: + deduped[msgid] = 
LoreMessage.get_preferred_duplicate(deduped[msgid], msg) + continue + deduped[msgid] = msg + return list(deduped.values()) def get_pi_thread_by_msgid(msgid, useproject=None, nocache=False, onlymsgids: Optional[set] = None): -- cgit v1.2.3