#!/usr/bin/env python3

import subprocess
import collections
import os
import sys
import re
import pprint


# Write-only handle on the null device; the git helpers in this file pass it
# as ``stderr=`` so that git's diagnostics are discarded.
FNULL = open(os.devnull, mode="w")

# Pretty-printer that writes to stderr, used by dbg_pprint().
pp = pprint.PrettyPrinter(stream=sys.stderr, indent=4)

# Debug output is enabled only when NM_FIND_BACKPORTS_DEBUG is exactly "1".
DEBUG = os.environ.get("NM_FIND_BACKPORTS_DEBUG") == "1"


def dbg_log(s):
    """Print *s* to stderr when the module-level DEBUG flag is set; otherwise do nothing."""
    if not DEBUG:
        return
    print(s, file=sys.stderr)


def dbg_pprint(obj):
    """Pretty-print *obj* with the module-level printer ``pp`` when DEBUG is set."""
    if not DEBUG:
        return
    pp.pprint(obj)


def print_err(s):
    """Print *s*, followed by a newline, to standard error."""
    err_stream = sys.stderr
    print(s, file=err_stream)


def die(s):
    """Report *s* through print_err() and terminate the process with exit status 1.

    Raises:
        SystemExit: always, with code 1.
    """
    print_err(s)
    raise SystemExit(1)


def memoize(f):
    """Return a caching wrapper around the one-argument function *f*.

    The wrapper calls ``f(x)`` at most once per distinct (hashable) argument
    and afterwards returns the stored result, including results that are
    ``None`` or otherwise falsy. The very same result object is handed out on
    every call, so callers must not mutate it.

    The wrapper keeps the name, docstring and other metadata of *f*
    (via functools.wraps), which keeps tracebacks and introspection readable.
    """
    # Function-scope import so the decorator does not depend on the file's
    # top-level import block.
    import functools

    memo = {}

    @functools.wraps(f)
    def helper(x):
        # Membership test rather than .get(): a cached None must count as a hit.
        if x not in memo:
            memo[x] = f(x)
        return memo[x]

    return helper


def re_bin(r):
    """Return the text *r* encoded as UTF-8 bytes.

    The callers in this file use it to build byte patterns that are matched
    against the raw (undecoded) output of git.
    """
    encoded = r.encode(encoding="utf8")
    return encoded


def _keys_to_dict(itr):
    """Build an OrderedDict whose keys are the items of *itr*, each mapped to None.

    Keys keep the order of their first appearance; repeated items collapse
    into a single key.
    """
    return collections.OrderedDict.fromkeys(itr)


@memoize
def git_ref_exists_full_path(ref):
    """Resolve *ref* like git_ref_exists(), but also require `git show-ref --verify` to accept it.

    Returns the value of git_ref_exists(ref) when that is truthy and
    ``git show-ref -q --verify <ref>`` exits successfully; otherwise None.
    (Per the git documentation, ``--verify`` demands an exact, full ref path
    such as ``refs/notes/bugs``.)
    """
    commit = git_ref_exists(ref)
    if not commit:
        return None
    try:
        subprocess.check_output(["git", "show-ref", "-q", "--verify", str(ref)])
    except subprocess.CalledProcessError:
        # show-ref rejected the name: treat the ref as not existing.
        return None
    return commit


def _git_ref_exists_eval(ref):
    """Ask git to resolve *ref* to a commit, without any caching.

    Runs ``git rev-parse --verify <ref>^{commit}`` with stderr discarded.

    Returns:
        The 40-character commit id printed by git, or None when git exits
        with an error (the ref cannot be resolved to a commit).

    Raises:
        Exception: git succeeded but its stripped output is not 40 characters long.
    """
    revision = str(ref) + "^{commit}"
    try:
        raw = subprocess.check_output(
            ["git", "rev-parse", "--verify", revision],
            stderr=FNULL,
        )
    except subprocess.CalledProcessError:
        return None
    sha = raw.decode("ascii").strip()
    if len(sha) != 40:
        raise Exception(f"git-rev-parse for '{ref}' returned unexpected output {raw}")
    return sha


# Maps a ref name (or commit id) to the resolved commit id, or to None when
# the ref could not be resolved.
_git_ref_exists_cache = {}


def git_ref_exists(ref):
    """Return the commit id that *ref* resolves to, or None if it does not resolve.

    Results (including negative ones) are kept in ``_git_ref_exists_cache``.
    A successful lookup also registers the commit id itself as a key, so a
    later query by full id is answered without running git.
    """
    if ref in _git_ref_exists_cache:
        return _git_ref_exists_cache[ref]

    resolved = _git_ref_exists_eval(ref)
    _git_ref_exists_cache[ref] = resolved
    if resolved and resolved != ref:
        _git_ref_exists_cache[resolved] = resolved
    return resolved


@memoize
def git_get_head_name(ref):
    """Return the stripped output of ``git rev-parse --symbolic-full-name <ref>``.

    git's stderr is discarded. A non-zero exit status from git propagates as
    subprocess.CalledProcessError.
    """
    cmd = ["git", "rev-parse", "--symbolic-full-name", str(ref)]
    raw = subprocess.check_output(cmd, stderr=FNULL)
    return raw.decode("utf-8").strip()


def git_merge_base(a, b):
    """Return the output of ``git merge-base <a> <b>`` as a stripped string.

    The result is sanity-checked with git_ref_exists(). A non-zero exit
    status from git propagates as subprocess.CalledProcessError.
    """
    cmd = ["git", "merge-base", str(a), str(b)]
    base = subprocess.check_output(cmd, stderr=FNULL).decode("ascii").strip()
    assert git_ref_exists(base)
    return base


def git_all_commits_grep(rnge, grep=None):
    """List the commit ids of ``git log --reverse <rnge>``, optionally filtered by --grep.

    Args:
        rnge: revision or revision range handed to ``git log``.
        grep: optional iterable of patterns; each becomes one ``--grep=``
            argument. When given, git is also run with
            ``-c notes.displayref=refs/notes/bugs``.

    Returns:
        The non-empty output lines of git (one commit id each), in the order
        git printed them.
    """
    cmd = ["git"]
    grep_args = []
    if grep:
        # Select the bugs notes ref for display; presumably so the patterns
        # can also hit text stored in those notes (--notes is always passed).
        cmd += ["-c", "notes.displayref=refs/notes/bugs"]
        grep_args = [f"--grep={g}" for g in grep]
    cmd += ["log", "--pretty=%H", "--notes", "--reverse"]
    cmd += grep_args
    cmd.append(str(rnge))

    raw = subprocess.check_output(cmd, stderr=FNULL)
    return [line for line in raw.decode("ascii").split("\n") if line]


def git_logg(commits):
    """Return a colorized one-line-per-commit ``git log`` listing of exactly *commits*.

    ``--no-walk`` restricts the output to the given commits (no ancestors).
    An empty input yields the empty string without running git.
    """
    revisions = [str(c) for c in commits]
    if not revisions:
        return ""

    cmd = [
        "git",
        "log",
        "--no-show-signature",
        "--no-walk",
        "--pretty=format:%Cred%h%Creset - %Cgreen(%ci)%Creset [%C(yellow)%an%Creset] %s%C(yellow)%d%Creset",
        "--abbrev-commit",
        "--date=local",
    ]
    raw = subprocess.check_output(cmd + revisions, stderr=FNULL)
    return raw.decode("utf-8").strip()


@memoize
def git_all_commits(rnge):
    """Return the memoized, unfiltered commit list for *rnge* (see git_all_commits_grep).

    Because of the memoization every call with the same *rnge* returns the
    same list object; callers must not modify it.
    """
    return git_all_commits_grep(rnge, grep=None)


@memoize
def git_all_commits_set(rnge):
    """Return the commits of *rnge* as a memoized set, for fast membership tests."""
    commit_ids = git_all_commits_grep(rnge)
    return set(commit_ids)


def git_commit_sorted(commits):
    """Return the ids of *commits* in the order ``git log --no-walk --reverse`` prints them.

    An empty input yields an empty list without running git.
    """
    revisions = [str(c) for c in commits]
    if not revisions:
        return []

    cmd = ["git", "log", "--no-walk", "--pretty=%H", "--reverse"]
    raw = subprocess.check_output(cmd + revisions, stderr=FNULL)
    return [line for line in raw.decode("ascii").split("\n") if line]


@memoize
def git_ref_commit_body(ref):
    """Return the raw commit message of *ref* followed by its notes, as bytes.

    Uses ``git log -n1 --pretty=%B%n%N`` with ``notes.displayref`` set to
    ``refs/notes/bugs``. The output is deliberately not decoded; callers
    match it with byte patterns.
    """
    cmd = [
        "git",
        "-c",
        "notes.displayref=refs/notes/bugs",
        "log",
        "-n1",
        "--pretty=%B%n%N",
        str(ref),
    ]
    return subprocess.check_output(cmd, stderr=FNULL)


@memoize
def git_ref_commit_body_get_fixes(ref):
    """Return the commits that *ref* declares to fix via "Fixes:" tags.

    The commit message and notes of *ref* (see git_ref_commit_body) are
    scanned for ``Fixes: <hash>`` / ``fixes: <hash>``. Every hash that
    resolves through git_ref_exists() is reported as a full commit id, in
    order of appearance.

    A "Fixes:" annotation can be disabled by an ``Ignore-Fixes: <hash>`` line.
    This mostly makes sense in refs/notes/bugs notes, to fix up a wrong
    annotation in the commit message itself.

    Returns:
        A list of commit ids (possibly empty).
    """
    body = git_ref_commit_body(ref)

    def resolve_matches(pattern):
        # Resolve group 1 of every match of `pattern`; drop unknown hashes.
        hashes = []
        for mo in re.finditer(re_bin(pattern), body):
            h = git_ref_exists(mo.group(1).decode("ascii"))
            if h:
                hashes.append(h)
        return hashes

    # The lookbehind is required: "\b" alone also matches between "-" and "F",
    # so an "Ignore-Fixes: X" line would otherwise be counted as "Fixes: X".
    fixes = resolve_matches("(?<!Ignore-)\\b[fF]ixes: *([0-9a-z]+)\\b")
    if not fixes:
        return fixes

    ignored = set(resolve_matches("\\bIgnore-[fF]ixes: *([0-9a-z]+)\\b"))
    # Filter instead of list.remove(): an ignored hash may appear more than
    # once (e.g. in the message and in a note) and every copy must go.
    return [h for h in fixes if h not in ignored]


@memoize
def git_ref_commit_body_get_cherry_picked_one(ref):
    """Return the commits that *ref* directly names as its cherry-pick origin.

    The message and notes of *ref* are scanned first for
    ``(cherry picked from commit <hash>)`` and then for
    ``Ignore-Backport: <hash>``. Hashes that git_ref_exists() cannot resolve
    are skipped.

    Returns:
        A non-empty list of commit ids, or None when *ref* does not exist or
        nothing was found.
    """
    commit = git_ref_exists(ref)
    if not commit:
        return None

    body = git_ref_commit_body(commit)
    patterns = (
        re_bin("\\(cherry picked from commit ([0-9a-z]+)\\)"),
        re_bin("\\bIgnore-Backport: *([0-9a-z]+)\\b"),
    )

    found = []
    for pattern in patterns:
        for mo in re.finditer(pattern, body):
            origin = git_ref_exists(mo.group(1).decode("ascii"))
            if origin:
                found.append(origin)
    # Callers rely on None (not an empty list) for "nothing found".
    return found or None


@memoize
def git_ref_commit_body_get_cherry_picked_recurse(ref):
    """Return all commits *ref* was cherry-picked from, following the chain transitively.

    Starting at *ref*, the direct origins reported by
    git_ref_commit_body_get_cherry_picked_one() are collected, then the
    origins of those origins, and so on (depth first).

    Returns:
        None when *ref* does not exist, otherwise a list of commit ids
        (possibly empty).
    """
    start = git_ref_exists(ref)
    if not start:
        return None

    collected = []
    pending = [start]
    while pending:
        current = pending.pop()
        origins = git_ref_commit_body_get_cherry_picked_one(current)
        if not origins:
            continue
        # Only follow commits not collected yet; this also ends reference cycles.
        fresh = [h for h in origins if h not in collected]
        collected.extend(fresh)
        # Push in reverse so the first new origin is expanded first.
        pending.extend(reversed(fresh))
    return collected


def git_commits_annotate_fixes(rnge):
    """Map every commit of *rnge* to the list of commits it fixes, or to None.

    Returns:
        An OrderedDict keyed by commit id in ``git log --reverse`` order. The
        value is the non-empty result of git_ref_commit_body_get_fixes() for
        commits carrying a usable "Fixes:" tag and None for all others.
    """
    annotated = _keys_to_dict(git_all_commits(rnge))
    # Only commits whose text matches the pattern need their body parsed.
    for commit in git_all_commits_grep(rnge, grep=["[Ff]ixes:"]):
        fixed = git_ref_commit_body_get_fixes(commit)
        if fixed:
            annotated[commit] = fixed
    return annotated


def git_commits_annotate_cherry_picked(rnge):
    """Map every commit of *rnge* to the commits it was cherry-picked from, or to None.

    Args:
        rnge: revision or revision range to inspect.

    Returns:
        An OrderedDict keyed by commit id in ``git log --reverse`` order. The
        value is the non-empty, transitive origin list from
        git_ref_commit_body_get_cherry_picked_recurse() for commits that
        carry a "cherry picked from commit" or "Ignore-Backport:" annotation,
        and None for all others.
    """
    commits = git_all_commits(rnge)
    c_dict = _keys_to_dict(commits)
    # Grep the range that was passed in, not the script-level global
    # `ref_head`: the global exists only when the file runs as a script and
    # is only correct when it happens to equal `rnge`.
    for c in git_all_commits_grep(
        rnge, grep=["cherry picked from commit", "Ignore-Backport:"]
    ):
        ff = git_ref_commit_body_get_cherry_picked_recurse(c)
        if ff:
            c_dict[c] = ff
    return c_dict


def git_ref_in_history(ref, rnge):
    """Return True when the commit that *ref* resolves to is part of *rnge*.

    An unresolvable *ref* resolves to None, which is never a member of the
    commit set, so the result is False in that case.
    """
    commit = git_ref_exists(ref)
    return commit in git_all_commits_set(rnge)


if __name__ == "__main__":
    # Usage: <script> [REF_HEAD [UPSTREAM_REF...]]
    #
    # Finds commits on the upstream refs that carry a "Fixes:" tag for a
    # commit that REF_HEAD contains (directly or as the origin of a
    # cherry-pick) and that were not yet backported to REF_HEAD. A
    # human-readable listing goes to stderr, the bare commit ids to stdout.

    # First argument: the branch to check; defaults to HEAD.
    if len(sys.argv) <= 1:
        ref_head0 = "HEAD"
    else:
        ref_head0 = sys.argv[1]

    # ref_head is the resolved commit id, ref_head0 the name as given.
    ref_head = git_ref_exists(ref_head0)
    if not ref_head:
        die('Ref "%s" does not exist' % (ref_head0))

    # The Fixes:/Ignore-* annotations are also read from the bugs notes, so
    # refuse to run without them.
    if not git_ref_exists_full_path("refs/notes/bugs"):
        die(
            "Notes refs/notes/bugs not found. Read CONTRIBUTING.md file for how to setup the notes"
        )

    # Determine the upstream refs to search. Without explicit upstream
    # arguments, and when the head is a branch named nm-1-<N>, probe
    # nm-1-<N+2>, nm-1-<N+4>, ... (local name first, then
    # refs/remotes/origin/) until one is missing, and finish with "main".
    ref_upstreams = []
    if len(sys.argv) <= 2:
        head_name = git_get_head_name(ref_head0)
        match = False
        if head_name:
            match = re.match("^refs/(heads|remotes/[^/]*)/nm-1-([0-9]+)$", head_name)
        if match:
            i = int(match.group(2))
            while True:
                i += 2
                r = "nm-1-" + str(i)
                if not git_ref_exists(r):
                    r = "refs/remotes/origin/nm-1-" + str(i)
                    if not git_ref_exists(r):
                        break
                ref_upstreams.append(r)
            ref_upstreams.append("main")

    # Fallbacks: "main" alone when nothing was detected, or the refs given
    # explicitly on the command line.
    if not ref_upstreams:
        if len(sys.argv) <= 2:
            ref_upstreams = ["main"]
        else:
            ref_upstreams = list(sys.argv[2:])

    for h in ref_upstreams:
        if not git_ref_exists(h):
            die('Upstream ref "%s" does not exist' % (h))

    print_err("Check %s (%s)" % (ref_head0, ref_head))
    print_err("Upstream refs: %s" % (ref_upstreams))

    print_err('Check patches of "%s"...' % (ref_head))
    # NOTE(review): own_commits_list is not read again below; the call still
    # fills the memoize cache of git_all_commits() for ref_head.
    own_commits_list = git_all_commits(ref_head)
    # commit on ref_head -> list of commits it was cherry-picked from (or None)
    own_commits_cherry_picked = git_commits_annotate_cherry_picked(ref_head)

    # Reverse mapping: original commit -> list of own commits that are
    # backports of it.
    cherry_picks_all = collections.OrderedDict()
    for c, cherry_picked in own_commits_cherry_picked.items():
        if cherry_picked:
            for c2 in cherry_picked:
                l = cherry_picks_all.get(c2)
                if not l:
                    cherry_picks_all[c2] = [c]
                else:
                    l.append(c)

    # Every commit "we have": the commits on ref_head plus all commits they
    # were cherry-picked from.
    own_commits_cherry_picked_flat = set()
    for c, p in own_commits_cherry_picked.items():
        own_commits_cherry_picked_flat.add(c)
        if p:
            own_commits_cherry_picked_flat.update(p)

    dbg_log(">>> own_commits_cherry_picked")
    dbg_pprint(own_commits_cherry_picked)

    dbg_log(">>> cherry_picks_all")
    dbg_pprint(cherry_picks_all)

    # find all commits on the upstream branches that fix another commit.
    fixing_commits = {}
    for ref_upstream in ref_upstreams:
        # Only commits reachable from the upstream ref but not from ref_head.
        ref_str = ref_head + ".." + ref_upstream
        print_err(f'Check upstream patches "{ref_str}"...')
        for c, fixes in git_commits_annotate_fixes(ref_str).items():
            if not fixes:
                dbg_log(f">>> test {c} : SKIP (does not fix anything)")
                continue
            if c in cherry_picks_all:
                # commit 'c' is already backported. Skip it.
                dbg_log(f">>> test {c} => {fixes} : SKIP (already backported)")
                continue
            dbg_log(f">>> test {c} => {fixes} : process")
            for f in fixes:
                if f not in own_commits_cherry_picked_flat:
                    # commit "c" fixes commit "f", but this is not one of our own commits
                    # and not interesting.
                    dbg_log(f">>> fixes {f} not in own_commits_cherry_picked")
                    continue
                dbg_log(f">>> take {c} (fixes {fixes})")
                fixing_commits[c] = fixes
                # One relevant fixed commit is enough to select "c".
                break

    # Several selected commits can be the same patch on different upstream
    # branches. Map each candidate to the commits it was cherry-picked from...
    extra = collections.OrderedDict(
        [(c, git_ref_commit_body_get_cherry_picked_recurse(c)) for c in fixing_commits]
    )
    # ...and drop every candidate that is itself the origin of another
    # candidate, so only the cherry-picked copies remain.
    extra2 = []
    for c in extra:
        is_back = False
        for e_v in extra.values():
            if c in e_v:
                is_back = True
                break
        if not is_back:
            extra2.append(c)

    commits_good = extra2

    # Bring the result into the order produced by git_commit_sorted().
    commits_good = git_commit_sorted(commits_good)

    print_err(git_logg(commits_good))

    # Warn about results that are not reachable from the first upstream ref.
    not_in = [
        c
        for c in commits_good
        if not git_ref_in_history(c, f"{ref_head}..{ref_upstreams[0]}")
    ]
    if not_in:
        print_err("")
        print_err(
            f'WARNING: The following commits are not from the first reference "{ref_upstreams[0]}".'
        )
        print_err(
            f'  You may want to first backports those patches to "{ref_upstreams[0]}".'
        )
        for l in git_logg(git_commit_sorted(not_in)).splitlines():
            print_err(f"  - {l}")
        print_err("")

    # Machine-readable result on stdout: one commit id per line, in the
    # reverse of the sorted order printed above.
    for c in reversed(commits_good):
        print("%s" % (c))
