import json
import math
import random

from tour import Tour


def euclidean_distance(pt1, pt2):
    """Return the straight-line (L2) distance between two points.

    Only the first two coordinates of each point are used, so extra
    coordinates are ignored.
    """
    delta_x = pt1[0] - pt2[0]
    delta_y = pt1[1] - pt2[1]
    return math.sqrt(delta_x ** 2 + delta_y ** 2)


def generate_matchings(L):
    """Yield every perfect matching of the items of ``L``.

    Each matching is a list of 2-tuples. Positions, not values, are paired,
    so duplicate items are matched as distinct entries. An empty ``L`` yields
    a single empty matching; an odd-length ``L`` yields nothing.
    """
    if len(L) == 0:
        yield []
        return
    head, tail = L[0], L[1:]
    for position, partner in enumerate(tail):
        # Everything except head and the partner chosen for it.
        leftovers = tail[:position] + tail[position + 1:]
        for rest in generate_matchings(leftovers):
            yield [(head, partner)] + rest


def one_random_matching(L):
    """Pair up the items of ``L`` at random.

    Returns a list of ``[a, b]`` lists covering every item of ``L`` once;
    ``L`` itself is not modified. An odd number of items makes the final
    ``random.randint`` call raise ValueError.
    """
    pool = list(L)
    pairs = []
    while pool:
        first = pool.pop(random.randint(0, len(pool) - 1))
        second = pool.pop(random.randint(0, len(pool) - 1))
        pairs.append([first, second])
    return pairs



class TSP:
    """A Euclidean travelling-salesman instance and operators on its tours.

    Cities are stored as tuples so they can be used as dict keys. A tour is a
    ``Tour`` built from a dict that maps every city to the next city visited;
    the last city maps back to the first.
    """

    def __init__(self, cities):
        # Convert each city (e.g. a JSON list) into a hashable tuple.
        self.cities = [tuple(city) for city in cities]

    @staticmethod
    def random_cities(N):
        """Return ``N`` random cities with coordinates in [0, 1], rounded to 2 places."""
        return [
            (round(random.random(), 2), round(random.random(), 2)) for _ in range(N)
        ]

    @staticmethod
    def from_file(file_name):
        """Build a TSP from a JSON file holding a list of city coordinates."""
        # ``with`` guarantees the file is closed even if parsing fails.
        with open(file_name, "r", encoding="utf-8") as f:
            return TSP(json.load(f))

    def score(self, tour):
        """Return the total Euclidean length of ``tour`` (sum over all its edges)."""
        return sum(euclidean_distance(v1, v2) for (v1, v2) in tour.edges.items())

    def tweak(self):
        """Placeholder: not implemented yet; does nothing and returns None."""
        pass

    def random_tour(self):
        """Return a uniformly random tour through the distinct cities.

        Raises ValueError if there are no cities.
        """
        # De-duplicate while keeping a deterministic (first-seen) order.
        order = list(dict.fromkeys(self.cities))
        if not order:
            raise ValueError("cannot build a tour without cities")
        random.shuffle(order)
        return TSP._tour_from_order(order)

    def random_grasp_tour(self, alpha):
        """Build a random greedy-ish tour for the GRASP metaheuristic.

        Start at a random city. At each step, compute the distance from the
        current city to every unvisited city. Then pick uniformly among the
        cities whose distance is <= c, where c = min + alpha * (max - min).
        ``alpha = 0`` gives nearest-neighbour behaviour (ties broken at
        random); ``alpha = 1`` allows any unvisited city.

        Raises ValueError if there are no cities.
        """
        remaining = list(dict.fromkeys(self.cities))
        if not remaining:
            raise ValueError("cannot build a tour without cities")

        source = remaining.pop(random.randrange(len(remaining)))
        order = [source]
        while remaining:
            distances = [euclidean_distance(source, city) for city in remaining]
            c_min = min(distances)
            c_max = max(distances)
            cutoff = c_min + alpha * (c_max - c_min)
            candidates = [i for i, d in enumerate(distances) if d <= cutoff]
            source = remaining.pop(random.choice(candidates))
            order.append(source)
        return TSP._tour_from_order(order)

    def k_opt(self, tour, k, try_only_random_matching=False):
        """Apply one k-opt move to ``tour`` and return the resulting new Tour.

        ``k`` random edges are removed, leaving 2k open endpoints. These are
        re-paired ("matched"). A re-pairing is accepted only if, together
        with the untouched edges, it forms one cycle through every city.

        In exhaustive mode (the default), every perfect matching is tried and
        the one with the smallest length change is kept. The original pairing
        is among them, so the result is never longer than ``tour`` (up to
        floating-point rounding). With ``try_only_random_matching=True``,
        random matchings are drawn until the first valid one, which is
        returned whatever its length.

        ``tour`` is not modified. Raises ValueError for an empty tour or when
        ``k`` exceeds the number of edges (from ``random.sample``).
        """
        if not tour.edges:
            raise ValueError("cannot apply k-opt to an empty tour")

        vertices = list(tour.edges.keys())
        sources_to_break = random.sample(vertices, k)
        sinks_to_break = [tour.edges[v] for v in sources_to_break]

        partial_edges = tour.edges.copy()
        score_loss = sum(
            euclidean_distance(source, partial_edges[source])
            for source in sources_to_break
        )
        for source in sources_to_break:
            del partial_edges[source]
        kept_edges = list(partial_edges.items())

        # Endpoints of the removed edges. A city that lost two edges appears
        # twice, once per missing edge.
        open_vertices = sources_to_break + sinks_to_break

        if try_only_random_matching:
            while True:
                matching = one_random_matching(open_vertices)
                cycle = self._trace_single_cycle(kept_edges + matching)
                if cycle is not None:
                    return self._tour_from_order(cycle)

        best_score_change = None
        best_cycle = None
        # Stream matchings instead of materialising the whole list.
        for matching in generate_matchings(open_vertices):
            cycle = self._trace_single_cycle(kept_edges + matching)
            if cycle is None:
                continue
            score_gain = sum(euclidean_distance(v1, v2) for (v1, v2) in matching)
            score_change = score_gain - score_loss
            if best_score_change is None or score_change < best_score_change:
                best_score_change = score_change
                best_cycle = cycle
        if best_cycle is None:
            raise ValueError("tour.edges does not describe a single cycle")
        return self._tour_from_order(best_cycle)

    @staticmethod
    def _trace_single_cycle(edge_list):
        """Return the vertices of ``edge_list`` in cycle order, or None.

        ``edge_list`` holds undirected edges as 2-item pairs. The walk starts
        at ``edge_list[0][0]`` and first moves along ``edge_list[0]``. It
        assumes every vertex has degree 2 (a self-loop counts twice). That
        holds for a tour with some edges removed plus a perfect matching of
        the removed edges' endpoints. Returns None if the edges do not form
        exactly one cycle, or if the list is empty.
        """
        if not edge_list:
            return None
        adjacency = {}
        for a, b in edge_list:
            adjacency.setdefault(a, []).append(b)
            adjacency.setdefault(b, []).append(a)

        start = edge_list[0][0]
        order = [start]
        previous, current = start, adjacency[start][0]
        steps = 1
        while current != start:
            if steps >= len(edge_list):
                return None  # malformed input; avoid walking forever
            order.append(current)
            neighbours = list(adjacency[current])
            # Drop one copy of the edge we arrived on. Parallel edges stay usable.
            neighbours.remove(previous)
            previous, current = current, neighbours[0]
            steps += 1
        # One cycle through all cities must use every edge exactly once.
        return order if steps == len(edge_list) else None

    @staticmethod
    def _tour_to_list(tour, start):
        """Unroll ``tour`` into a list of cities, following edges from ``start``."""
        n = len(tour.edges)
        order = [start]
        while len(order) < n:
            order.append(tour.edges[order[-1]])
        return order

    @staticmethod
    def _tour_from_order(order):
        """Build a Tour visiting ``order`` in sequence and closing back to the start."""
        return Tour(dict(zip(order, order[1:] + order[:1])))

    @staticmethod
    def partially_mapped_crossover(tour1, tour2):
        """Partially mapped crossover (PMX) of two tours.

        Both parents are unrolled from random start cities. Two random cut
        points are chosen. The child copies parent 2 between the cuts and
        takes parent 1's city everywhere else. If that city already occurs in
        the copied section, it is replaced by following the section's
        parent-2 -> parent-1 mapping until a city outside the section is
        reached.
        """
        assert len(tour1.edges) == len(tour2.edges)
        n = len(tour1.edges)
        city_list = list(tour1.edges.keys())

        T1 = TSP._tour_to_list(tour1, random.choice(city_list))
        T2 = TSP._tour_to_list(tour2, random.choice(city_list))

        lo, hi = sorted(random.sample(range(n + 1), 2))
        child = [None] * n
        child[lo:hi] = T2[lo:hi]

        # Map each city of the copied section to parent 1's city at that position.
        middle_map = {T2[i]: T1[i] for i in range(lo, hi)}

        for index in range(n):
            if lo <= index < hi:
                continue
            city = T1[index]
            while city in middle_map:
                city = middle_map[city]
            child[index] = city

        return TSP._tour_from_order(child)

    @staticmethod
    def order_crossover(tour1, tour2):
        """Order crossover (OX1) of two tours.

        The child copies parent 2 between two random cut points. It then
        fills the remaining slots, starting at the second cut point and
        wrapping around, with parent 1's cities in parent-1 order (also read
        from the second cut point), skipping cities already present.
        """
        assert len(tour1.edges) == len(tour2.edges)
        n = len(tour1.edges)
        city_list = list(tour1.edges.keys())

        T1 = TSP._tour_to_list(tour1, random.choice(city_list))
        T2 = TSP._tour_to_list(tour2, random.choice(city_list))

        lo, hi = sorted(random.sample(range(n + 1), 2))
        child = [None] * n
        child[lo:hi] = T2[lo:hi]
        placed = set(child[lo:hi])  # O(1) membership instead of scanning child

        slot = hi % n
        t1_index = hi % n
        while len(placed) < n:
            while T1[t1_index] in placed:
                t1_index = (t1_index + 1) % n
            child[slot] = T1[t1_index]
            placed.add(T1[t1_index])
            slot = (slot + 1) % n

        return TSP._tour_from_order(child)

    @staticmethod
    def cycle_crossover(tour1, tour2):
        """Cycle crossover (CX) of two tours.

        Both parents are unrolled from the same (minimum) city, so their
        positions line up. Positions are split into cycles. Each cycle is
        copied as a whole from one parent, chosen at random per cycle.
        """
        assert len(tour1.edges) == len(tour2.edges)
        n = len(tour1.edges)
        city_list = list(tour1.edges.keys())

        start = min(city_list)
        T1 = TSP._tour_to_list(tour1, start)
        T2 = TSP._tour_to_list(tour2, start)
        t1_position = {city: i for i, city in enumerate(T1)}
        t2_position = {city: i for i, city in enumerate(T2)}

        child = [None] * n
        placed = set()
        # Slots are only ever filled, so scanning forward finds each cycle's
        # lowest free position in turn.
        for first_free in range(n):
            if child[first_free] is not None:
                continue
            if random.randint(0, 1) == 0:
                active_parent, other_parent, active_position = T1, T2, t1_position
            else:
                active_parent, other_parent, active_position = T2, T1, t2_position

            index = first_free
            child[index] = active_parent[index]
            placed.add(child[index])
            other_entry = other_parent[index]
            while other_entry not in placed:
                index = active_position[other_entry]
                child[index] = other_entry
                placed.add(other_entry)
                other_entry = other_parent[index]

        return TSP._tour_from_order(child)

    @staticmethod
    def graph_partitioning_crossover(tour1, tour2):
        """Placeholder: not implemented yet; returns None."""
        pass

    @staticmethod
    def _graph_partition(tour1, tour2):
        """Split the cities into components of the union graph of two tours.

        Work-in-progress helper for graph_partitioning_crossover. A component
        grows from an unvisited city along edges that belong to exactly one of
        the two tours. A city with four distinct neighbours across both tours
        is added to the component it is reached from but not expanded further.

        Returns a list of non-empty components, each a list of cities. Every
        city appears in exactly one component. Component order depends on set
        iteration order.
        """
        vertices = set(tour1.edges.values())
        cities = set(vertices)
        checked = set()

        # Invert each successor map once instead of scanning all cities per lookup.
        t1_predecessor = {succ: pred for pred, succ in tour1.edges.items()}
        t2_predecessor = {succ: pred for pred, succ in tour2.edges.items()}
        t1_neighbor_dict = {c: {tour1.edges[c], t1_predecessor[c]} for c in cities}
        t2_neighbor_dict = {c: {tour2.edges[c], t2_predecessor[c]} for c in cities}

        components = []
        while vertices:
            to_check = {vertices.pop()}
            active_component = []
            while to_check:
                active = to_check.pop()
                if active in checked:
                    continue
                checked.add(active)
                active_component.append(active)
                vertices.discard(active)

                t1_adjacent = t1_neighbor_dict[active]
                t2_adjacent = t2_neighbor_dict[active]
                # No edge shared by both tours: treat this vertex as a dead end.
                if len(t1_adjacent | t2_adjacent) == 4:
                    continue

                # Unvisited neighbours in each tour.
                t1_neighbors = [c for c in t1_adjacent if c in vertices]
                t2_neighbors = [c for c in t2_adjacent if c in vertices]

                # Follow only edges that appear in exactly one of the tours.
                for N1 in t1_neighbors:
                    if N1 not in t2_neighbors:
                        to_check.add(N1)
                for N2 in t2_neighbors:
                    if N2 not in t1_neighbors:
                        to_check.add(N2)
            components.append(active_component)

        return components
