# An implementation of the product monomial crystal in any Dynkin type.
# Author: Joel Gibson, 2018 (http://www.maths.usyd.edu.au/u/joelg/index.html)

import bisect
import itertools

from collections import OrderedDict, namedtuple, Counter

def flatten_lists(t):
    """Yield the non-list leaves of an arbitrarily nested list, left to right.

    Any element that is not a ``list`` (tuples included) is treated as a leaf.
    The traversal keeps an explicit stack of iterators instead of recursing.
    As a result, very deeply nested inputs, such as the linked S-set chains
    built by Monomial, cannot exceed Python's recursion limit. Each leaf also
    costs O(1) to produce, rather than one generator hop per nesting level.

    >>> list(flatten_lists([[(3, 2), []], [], [[[(4, 5)]]]]))
    [(3, 2), (4, 5)]
    """
    stack = [iter(t)]
    while stack:
        for x in stack[-1]:
            if isinstance(x, list):
                # Descend into the sublist; the current iterator stays on the
                # stack and resumes after the sublist is exhausted.
                stack.append(iter(x))
                break
            yield x
        else:
            # The innermost iterator is exhausted: return to its parent.
            stack.pop()

class Vect:
    """An immutable integer vector indexed by the integers, with finite support.

    Besides pointwise addition, a Vect can evaluate the semi-infinite partial
    sums v[k] + v[k+1] + ... ("upper") and ... + v[k-1] + v[k] ("lower"), and
    locate the positions where those partial sums are extremal."""

    # Representation: two parallel lists of equal length. `_keys` lists the
    # support in strictly increasing order, and `_vals[i]` is the nonzero entry
    # stored at `_keys[i]`. No method mutates either list after construction,
    # so new instances may share them with old ones.

    def __init__(self, keys, vals):
        """Wrap a strictly increasing list of keys and the matching nonzero values.

        The invariants are trusted, not checked; use Vect.D to build a Vect
        from an arbitrary dict."""
        self._keys = keys
        self._vals = vals

    @staticmethod
    def D(d):
        """Build a Vect from a dict of int -> int, discarding zero entries.

        >>> Vect.D({1: 1, 2: 0, 3: -1})
        Vect.D({1: 1, 3: -1})
        """
        support = sorted(key for key, val in d.items() if val != 0)
        return Vect(support, [d[key] for key in support])

    @staticmethod
    def zero():
        """Return the vector whose entries are all zero.

        >>> Vect.zero()
        Vect.D({})
        """
        return Vect([], [])

    @property
    def dict(self):
        """The nonzero entries as a plain dict, inserted in increasing key order."""
        return {key: val for key, val in self.pairs()}

    def __repr__(self):
        return f"Vect.D({self.dict})"

    def is_zero(self):
        """Return True when the vector has empty support."""
        return not self._keys

    def sum(self):
        """Return the sum of all entries."""
        return sum(self._vals)

    def add(self, key, val):
        """Return a vector equal to this one with `val` added at position `key`.

        An entry that becomes zero is removed from the support. Adding zero at
        a key outside the support returns this very instance.

        >>> Vect.D({1: 1, 2: 1}).add(2, -1)
        Vect.D({1: 1})
        >>> Vect.zero().add(2, -1)
        Vect.D({2: -1})
        """
        keys, vals = self._keys, self._vals
        # pos is where `key` sits, or would be inserted, in the sorted key list.
        pos = bisect.bisect_left(keys, key)

        if pos == len(keys) or keys[pos] != key:
            # `key` is outside the current support.
            if val == 0:
                return self
            return Vect(keys[:pos] + [key] + keys[pos:],
                        vals[:pos] + [val] + vals[pos:])

        total = vals[pos] + val
        if total == 0:
            return Vect(keys[:pos] + keys[pos + 1:], vals[:pos] + vals[pos + 1:])

        updated = list(vals)
        updated[pos] = total
        return Vect(keys, updated)

    def __add__(self, other):
        """Return the pointwise sum of two vectors.

        >>> Vect.D({1: 1, 2: 1, 4: 5}) + Vect.D({2: -1, 3: 3, 4: 4})
        Vect.D({1: 1, 3: 3, 4: 9})
        """
        assert isinstance(other, Vect)

        lkeys, lvals = self._keys, self._vals
        rkeys, rvals = other._keys, other._vals
        keys, vals = [], []
        a = b = 0

        # Merge the two sorted supports, dropping entries that cancel.
        while a < len(lkeys) and b < len(rkeys):
            if lkeys[a] == rkeys[b]:
                total = lvals[a] + rvals[b]
                if total != 0:
                    keys.append(lkeys[a])
                    vals.append(total)
                a += 1
                b += 1
            elif lkeys[a] < rkeys[b]:
                keys.append(lkeys[a])
                vals.append(lvals[a])
                a += 1
            else:
                keys.append(rkeys[b])
                vals.append(rvals[b])
                b += 1

        # At most one side has entries left over, and they are already sorted.
        keys.extend(lkeys[a:])
        vals.extend(lvals[a:])
        keys.extend(rkeys[b:])
        vals.extend(rvals[b:])
        return Vect(keys, vals)

    def pairs(self):
        """Return an iterator of (key, value) pairs in increasing key order."""
        return zip(self._keys, self._vals)

    def upper_partial(self, k):
        """Return v[k] + v[k+1] + v[k+2] + ...

        >>> v = Vect.D({1: 1, 2: 2, 3: 3})
        >>> v.upper_partial(0)
        6
        >>> v.upper_partial(1)
        6
        >>> v.upper_partial(2)
        5
        >>> v.upper_partial(3)
        3
        >>> v.upper_partial(4)
        0
        """
        # Every entry from index `start` onward has key >= k.
        start = bisect.bisect_left(self._keys, k)
        return sum(self._vals[start:])

    def max_upper_partial(self):
        """Find the largest position k maximising the upper partial sum.
        Returns (k, upper_partial(k)).

        Only positions in the support are considered, so the returned sum may
        be negative. Among tied positions, the largest key wins. The vector
        must be nonzero.

        >>> Vect.D({-3: 1, -1: -2, 1: 2, 3: -1, 4: 1}).max_upper_partial()
        (1, 2)
        """
        assert len(self._keys) != 0

        best_key, best = self._keys[-1], self._vals[-1]
        running = 0
        # Walk the support from the top down, accumulating the tail sum.
        for key, val in zip(reversed(self._keys), reversed(self._vals)):
            running += val
            if running > best:
                best_key, best = key, running
        return best_key, best

    def min_lower_partial(self):
        """Find the lowest position k minimising the lower partial sum.
        Returns (k, lower_partial(k)).

        Only positions in the support are considered, so the returned sum may
        be positive. Among tied positions, the smallest key wins. The vector
        must be nonzero.

        >>> Vect.D({-3: 1, -1: -2, 1: 2, 3: -1, 4: 1}).min_lower_partial()
        (-1, -1)
        """
        assert len(self._keys) != 0

        best_key, best = self._keys[0], self._vals[0]
        running = 0
        # Walk the support from the bottom up, accumulating the head sum.
        for key, val in zip(self._keys, self._vals):
            running += val
            if running < best:
                best_key, best = key, running
        return best_key, best

    def lower_partial(self, k):
        """Return ... + v[k-2] + v[k-1] + v[k]

        >>> v = Vect.D({1: 1, 2: 2, 3: 3})
        >>> v.lower_partial(0)
        0
        >>> v.lower_partial(2)
        3
        >>> v.lower_partial(3)
        6
        >>> v.lower_partial(4)
        6
        """
        # Every entry before index `stop` has key <= k.
        stop = bisect.bisect_right(self._keys, k)
        return sum(self._vals[:stop])
        
    
# A simply-laced Dynkin diagram is a mapping of int -> set of ints, which should be the
# adjacency list of a simple graph which is a connected tree. Nodes are numbered from
# 1 up to n, where n is the parameter, and the mapping is ordered by node number.
def order_dict(d):
    """Return an OrderedDict holding the items of `d` in increasing key order."""
    ordered = OrderedDict()
    for key in sorted(d):
        ordered[key] = d[key]
    return ordered


def A(n):
    """Return the Dynkin diagram of type A_n (n >= 1): the path 1 - 2 - ... - n.

    >>> A(1) == {1: set()}
    True
    >>> A(2) == {1: {2}, 2: {1}}
    True
    >>> A(3) == {1: {2}, 2: {1, 3}, 3: {2}}
    True
    """
    assert n >= 1

    neighbours = {}
    for node in range(1, n + 1):
        # Each node is joined to whichever of its numeric neighbours exist.
        neighbours[node] = {other for other in (node - 1, node + 1) if 1 <= other <= n}
    return order_dict(neighbours)


def D(n):
    """Return the Dynkin diagram of type D_n (n >= 3).

    The diagram is the path 1 - 2 - ... - (n-2), with the two nodes n-1 and n
    both attached to node n-2. (For n = 3 this is the path 2 - 1 - 3.)

    >>> D(3) == {1: {2, 3}, 2: {1}, 3: {1}}
    True
    >>> D(4) == {1: {2}, 2: {1, 3, 4}, 3: {2}, 4: {2}}
    True
    """
    assert n >= 3

    graph = {node: set() for node in range(1, n + 1)}

    def link(a, b):
        graph[a].add(b)
        graph[b].add(a)

    for node in range(1, n - 2):
        link(node, node + 1)

    branch = n - 2
    link(branch, n - 1)
    link(branch, n)
    return order_dict(graph)

def E(n):
    """Return the Dynkin diagram of type E_n, for n in {6, 7, 8}.

    Nodes follow Bourbaki numbering: a chain 1 - 3 - 4 - 5 - 6 - 7 - 8
    (truncated at node n) with node 2 attached to node 4.

    >>> E(6) == {1: {3}, 2: {4}, 3: {1, 4}, 4: {2, 3, 5}, 5: {4, 6}, 6: {5}}
    True
    """
    edges = [(1, 3), (3, 4), (2, 4), (4, 5), (5, 6), (6, 7), (7, 8)]
    # E_n has n - 1 edges; keeping a prefix of `edges` truncates the long arm.
    edge_count = {6: 5, 7: 6, 8: 7}

    assert n in edge_count

    graph = {}
    for a, b in edges[:edge_count[n]]:
        graph.setdefault(a, set()).add(b)
        graph.setdefault(b, set()).add(a)
    return order_dict(graph)


class Monomial:
    """A Laurent monomial in variables Y_{i,k}, for nodes i of a Dynkin diagram
    and integers k, together with the S-set recorded while it was built.

    The exponents for node i are kept in a Vect indexed by k. Equality and
    hashing depend only on the exponents, never on the S-set."""

    def __init__(self, dynkin, vects, sset=None):
        """Create a new monomial.

        dynkin -- the Dynkin diagram, a mapping node -> set of neighbours.
        vects  -- a mapping node -> Vect of exponents, with an entry for every
                  node of dynkin.
        sset   -- the S-set, as a nested series of lists whose leaves are
                  (node, k) tuples. Each operation links a new list around the
                  previous structure instead of copying a whole list each
                  time; flatten_lists recovers the entries. Defaults to empty.
        """
        self._dynkin = dynkin
        self._vects = vects
        # Hashable snapshot of the exponents in Dynkin node order. This is
        # the sole basis for __eq__ and __hash__.
        self._tuple = tuple(tuple(vects[i].pairs()) for i in dynkin)
        # Use a fresh list rather than a shared mutable default argument.
        self._sset = [] if sset is None else sset

    def __repr__(self):
        return f"Monomial({ {node: self._vects[node].dict for node in self._dynkin} })"

    def __eq__(self, other):
        # For non-Monomials, let Python fall back to its default comparison
        # instead of failing with AttributeError on other._tuple.
        if not isinstance(other, Monomial):
            return NotImplemented
        return self._tuple == other._tuple

    def __hash__(self):
        return hash(self._tuple)

    def weight(self):
        """Return the tuple of total exponents at each node, in Dynkin node order."""
        return tuple(self._vects[node].sum() for node in self._dynkin)

    def is_one(self):
        """Return True when every exponent is zero, i.e. this is the monomial 1."""
        # Look up the Vects explicitly: iterating self._vects itself would
        # yield the node keys, which have no is_zero() method.
        return all(self._vects[node].is_zero() for node in self._dynkin)

    def is_hw(self):
        """Return True when, at every node, every lower partial sum of the
        exponents is non-negative.

        >>> [x for x in product_crystal(D(4), {1:[1]}) if x.is_hw()]
        [Monomial({1: {1: 1}, 2: {}, 3: {}, 4: {}})]
        """
        for node in self._dynkin:
            if self._vects[node].is_zero():
                continue

            _, val = self._vects[node].min_lower_partial()
            if val < 0:
                return False

        return True

    def is_lw(self):
        """Return True when, at every node, every upper partial sum of the
        exponents is non-positive.

        >>> [x for x in product_crystal(D(4), {1:[1]}) if x.is_lw()]
        [Monomial({1: {-5: -1}, 2: {}, 3: {}, 4: {}})]
        """
        for node in self._dynkin:
            if self._vects[node].is_zero():
                continue

            _, val = self._vects[node].max_upper_partial()
            if val > 0:
                return False

        return True

    def f(self, node):
        """Apply the crystal operator f at `node`, returning a new Monomial.

        Returns None when the exponents at `node` are all zero, or when their
        maximal upper partial sum is not positive. Otherwise, let k be the
        largest position attaining that maximum. The result divides by
        Y_{node,k} and Y_{node,k-2}, multiplies by Y_{j,k-1} for each
        neighbour j of `node`, and records (node, k-1) in the S-set.

        >>> Monomial.Y(A(1), 1, 1).f(1)
        Monomial({1: {-1: -1}})
        >>> Monomial.Y(A(3), 1, 1).f(1)
        Monomial({1: {-1: -1}, 2: {0: 1}, 3: {}})
        >>> Monomial.Y(A(3), 1, 1).f(1).f(2)
        Monomial({1: {}, 2: {-2: -1}, 3: {-1: 1}})
        """
        assert node in self._dynkin

        if self._vects[node].is_zero():
            return None

        k, val = self._vects[node].max_upper_partial()
        if val <= 0:
            return None

        # Shallow copy: the Vects themselves are immutable.
        new_vects = dict(self._vects)
        new_vects[node] = new_vects[node].add(k, -1).add(k - 2, -1)
        for neigh in self._dynkin[node]:
            new_vects[neigh] = new_vects[neigh].add(k - 1, 1)

        return Monomial(self._dynkin, new_vects, sset=[(node, k - 1), self._sset])

    def __mul__(self, other):
        """Multiply two monomials over the same diagram, linking their S-sets.

        >>> Monomial.Y(A(3), 1, -1) * Monomial.Y(A(3), 2, 8)
        Monomial({1: {-1: 1}, 2: {8: 1}, 3: {}})
        """
        assert isinstance(other, Monomial)

        return Monomial(self._dynkin, {node: self._vects[node] + other._vects[node] for node in self._dynkin}, sset=[self._sset, other._sset])

    def sset(self):
        """Return the S-multiset of this monomial, as a Counter of (node, k) pairs."""
        return Counter(flatten_lists(self._sset))

    @staticmethod
    def Y(dynkin, i, c):
        """Return the monomial Y_{i,c}: exponent 1 at (i, c), zero elsewhere.

        >>> Monomial.Y(A(2), 2, -100)
        Monomial({1: {}, 2: {-100: 1}})
        """
        assert i in dynkin
        return Monomial(dynkin, {node: Vect.D({c: 1}) if node == i else Vect.zero() for node in dynkin})

    @staticmethod
    def One(dynkin):
        """Return the monomial 1 (all exponents zero) over `dynkin`.

        >>> Monomial.One(A(3))
        Monomial({1: {}, 2: {}, 3: {}})
        """
        return Monomial(dynkin, {node: Vect.zero() for node in dynkin})


def bipartition(dynkin):
    """Produce one side of a bipartition of the given Dynkin diagram.

    The side returned is the one containing node 1. It is found by a
    breadth-first search from node 1, taking the union of the even-numbered
    levels.

    Raises ValueError if some nodes are unreachable from node 1, i.e. the
    diagram is not connected. Without this check the search would never
    terminate.

    >>> bipartition(A(1))
    {1}
    >>> bipartition(A(5))
    {1, 3, 5}
    >>> bipartition(D(3))
    {1}
    >>> bipartition(D(4))
    {1, 3, 4}
    """
    assert 1 in dynkin

    levels = [{1}]
    seen = {1}
    while len(seen) < len(dynkin):
        new_level = {neigh for node in levels[-1] for neigh in dynkin[node]} - seen
        if not new_level:
            # The search has stalled with nodes still unseen.
            raise ValueError("Dynkin diagram is not connected")
        seen |= new_level
        levels.append(new_level)

    return set().union(*levels[::2])

def check_parity(dynkin, R):
    """Return True if some bipartition of the Dynkin diagram gives R consistent
    parity.

    Consistent parity means that every parameter at a node on one side is
    even, and every parameter at a node on the other side is odd.

    >>> check_parity(A(3), {1:[1], 2:[2], 3:[3]})
    True
    >>> check_parity(A(3), {1:[0], 2:[1], 3:[2]})
    True
    >>> check_parity(A(6), {1:[1], 6:[1]})
    False
    >>> check_parity(D(4), {1:[0], 2:[1], 3:[0], 4:[0]})
    True
    """
    # A connected A/D/E diagram has only two bipartitions: the side returned
    # by bipartition() and its complement, i.e. the same split with the
    # parities flipped. Try both.
    side = bipartition(dynkin)

    for flip in (0, 1):
        # With flip == 0, nodes in `side` expect even parameters; with
        # flip == 1, they expect odd ones.
        expected = {node: (0 if node in side else 1) ^ flip for node in dynkin}
        if all(c % 2 == expected[node] for node in R for c in R[node]):
            return True

    return False


def generate(dynkin, monomials):
    """Return the closure of `monomials` under the f_i operators, over all nodes i.

    The search runs in rounds. Each round applies every f_i to the monomials
    first discovered in the previous round and keeps only those not already
    in the closure.

    >>> len(generate(A(4), {Monomial.Y(A(4), 1, 1)}))
    5
    >>> len(generate(D(5), {Monomial.Y(D(5), 1, 1)}))
    10
    """
    def successors(mono):
        # Every defined f_i(mono); f returns None when it is undefined.
        for node in dynkin:
            image = mono.f(node)
            if image:
                yield image

    frontier = set(monomials)
    closure = set(monomials)

    while frontier:
        fresh = set()
        for mono in frontier:
            for image in successors(mono):
                if image not in closure:
                    fresh.add(image)
                    closure.add(image)
        frontier = fresh

    return closure


def fundamental(dynkin, node, c):
    """Return the closure of the single monomial Y_{node,c} under all f operators.

    >>> len(fundamental(A(5), 3, -100)) # 6 choose 3
    20
    """
    assert node in dynkin

    seed = Monomial.Y(dynkin, node, c)
    return generate(dynkin, {seed})


def product_crystal(dynkin, R):
    """Return the underlying set of the product monomial crystal for parameters R.

    R maps nodes of `dynkin` to lists of integers. Every pair (i, c) with c in
    R[i] contributes the factor fundamental(dynkin, i, c). The result is the
    set of all products formed by taking one monomial from each factor.

    Raises ValueError if R names a node outside `dynkin`, or if R's parameters
    do not have consistent parity (see check_parity).

    >>> sorted([elem.weight() for elem in product_crystal(A(4), {1:[1], 3:[101]}) if elem.is_hw()])
    [(0, 0, 0, 1), (1, 0, 1, 0)]
    """
    # Explicit exceptions rather than asserts, so that validation of
    # user-supplied R still happens under `python -O`.
    if not all(node in dynkin for node in R):
        raise ValueError("There are nodes not belonging to the Dynkin diagram")
    if not check_parity(dynkin, R):
        raise ValueError("R has inconsistent parity")

    # Try to sort (i, c) pairs so that c's which are close are generated
    # close together. This should maximise cancellation and keep the working
    # set as small as possible for as long as possible.
    order = sorted([(i, c) for i in R for c in R[i]], key=lambda pair: pair[1])

    # R is a multiset, so the same (node, c) pair may occur several times.
    # Build each fundamental crystal only once and reuse it.
    fundamentals = {}
    elements = {Monomial.One(dynkin)}
    for node, c in order:
        if (node, c) not in fundamentals:
            fundamentals[(node, c)] = fundamental(dynkin, node, c)
        fund = fundamentals[(node, c)]
        elements = {elem1 * elem2 for elem1 in elements for elem2 in fund}

    return elements


def weight_data(monomials):
    """Tally the weights of the highest-weight monomials in `monomials`.

    Returns a list of (weight, multiplicity) pairs sorted by weight. The
    multiplicity is the number of highest-weight monomials with that weight.
    """
    tally = Counter()
    for mono in monomials:
        if mono.is_hw():
            tally[mono.weight()] += 1
    return sorted(tally.items())


def show_crystal_data(name, dynkin, R, underlying_set):
    """Print a summary of a crystal, given the set of monomials underlying it.

    name           -- the label printed for the Dynkin type (e.g. "A4").
    dynkin         -- the Dynkin diagram.
    R              -- the parameter multiset, a mapping node -> list of ints.
    underlying_set -- the monomials making up the crystal.

    The summary shows:
      - lambda, the number of parameters R assigns to each node;
      - the number of monomials;
      - the number of connected components and of their isoclasses, both
        counted via highest-weight elements;
      - a table of the highest weights and their multiplicities.
    """
    # lambda lists, in Dynkin node order, how many parameters R has at each node.
    lamb = [len(R.get(node, [])) for node in dynkin]
    highest = weight_data(underlying_set)
    components = sum(count for _, count in highest)

    report = [
        f"Product monomial crystal in type {name},",
        f"with parameters lambda = {lamb}, R = {R}",
        "",
        f"Number of monomials (dimension):    {len(underlying_set)}",
        f"Number of connected components:     {components}",
        f"Number of isoclasses of components: {len(highest)}",
        "Listing of isoclasses of highest weights:",
        f"  {'Weight':>25}  {'Mult':>4}",
    ]
    report += [f"  {weight!r:>25}  {count:4d}" for weight, count in highest]
    print("\n".join(report))
    print()

def _print_shapes(dynkin, elems, kind):
    """Print the S-multiset of each monomial in `elems` as a grid, in order of weight.

    `kind` is the phrase used in each heading (e.g. "highest-weight"). Grid
    rows are the values k occurring in the S-multiset, from largest to
    smallest. Columns follow the node order of `dynkin`. Each cell holds the
    multiplicity of (node, k), or a blank when (node, k) does not occur.
    """
    for elem in sorted(elems, key=lambda m: m.weight()):
        print('----------')
        print(f'Shape of {kind} elem with weight {elem.weight()}')
        sset = elem.sset()
        rows = {k for _, k in sset}
        if not rows:
            # Nothing to draw, and min()/max() below would fail on an empty set.
            print('empty')
            continue
        digits = max(len(str(k)) for k in rows)
        for k in range(max(rows), min(rows) - 1, -1):
            cells = "".join(str(sset[(i, k)]) if (i, k) in sset else " " for i in dynkin)
            print(f"{k:>{digits}}: {cells}")


if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser(description="Generates the product monomial crystal.")
    parser.add_argument("type", choices=['A', 'D', 'E'], help="the Lie algebra type")
    parser.add_argument("n", type=int, help="number of nodes in the Dynkin diagram")
    parser.add_argument("R", help="the parameter multiset")
    parser.add_argument("--show_hws", help="Show pictures of the highest-weight elements", action="store_true")
    parser.add_argument("--show_lws", help="Show pictures of the lowest-weight elements", action="store_true")
    args = parser.parse_args()

    dynkins = {
        'A': A,
        'D': D,
        'E': E
    }

    if args.type not in dynkins:
        parser.error(message=f"Unrecognised dynkin type {args.type}")

    dynkin_str = args.type + str(args.n)
    dynkin = dynkins[args.type](args.n)
    # NOTE(review): eval() executes arbitrary Python taken from the command
    # line. That is tolerable for a local research script, but if this is
    # ever fed untrusted input, switch to ast.literal_eval. R is expected to
    # be a dict literal such as "{1: [1], 3: [101]}".
    R = eval(args.R)

    product = product_crystal(dynkin, R)
    show_crystal_data(dynkin_str, dynkin, R, product)

    if args.show_hws:
        _print_shapes(dynkin, (x for x in product if x.is_hw()), 'highest-weight')

    if args.show_lws:
        _print_shapes(dynkin, (x for x in product if x.is_lw()), 'lowest-weight')
