# random solution AND random greedy solution

import math
import random
random.seed(1)
import itertools
from tqdm import tqdm

import matplotlib.pyplot as plt  # type: ignore


def euclidean_distance(pt1, pt2):
    """Return the straight-line distance between two 2-D points.

    Only the first two coordinates (index 0 and index 1) of each point are
    read; any further elements are ignored.
    """
    delta_x = pt1[0] - pt2[0]
    delta_y = pt1[1] - pt2[1]
    return math.sqrt(delta_x ** 2 + delta_y ** 2)


class TSPSolution:
    """A candidate tour for the travelling-salesman problem.

    The tour visits ``points`` in the given order and then returns from the
    last point to the first one. Instances are treated as immutable: every
    method that produces a variation returns a new ``TSPSolution``.
    """

    def __init__(self, points):
        # Points in visiting order; the sequence is stored as given, not copied.
        self.points = points

    def path_length(self):
        """Return the total length of the closed tour.

        The result is the sum of the distances between consecutive points
        plus the closing edge from the last point back to the first.
        An empty tour has length 0.0.
        """
        points = self.points
        if len(points) == 0:
            # Without this guard the closing edge would index an empty sequence.
            return 0.0
        open_path = sum(
            euclidean_distance(points[i], points[i + 1])
            for i in range(len(points) - 1)
        )
        return open_path + euclidean_distance(points[-1], points[0])

    def tweak(self):
        """Return a new solution with two randomly chosen points swapped.

        Index 0 is never selected, so the first point of the tour stays put.

        Raises:
            ValueError: if the tour has fewer than three points, because
                there are then fewer than two swappable positions.
        """
        first, second = random.sample(range(1, len(self.points)), 2)
        new_points = list(self.points)
        new_points[first], new_points[second] = (
            new_points[second],
            new_points[first],
        )
        return TSPSolution(new_points)

    def neighborhood(self):
        """Yield every solution reachable by swapping one pair of points.

        Pairs are taken from positions 1..len-1 (position 0 stays fixed) in
        the order produced by ``itertools.combinations``. Solutions are
        generated lazily; a tour with fewer than three points yields nothing.
        """
        for first, second in itertools.combinations(
            range(1, len(self.points)), 2
        ):
            new_points = list(self.points)
            new_points[first], new_points[second] = (
                new_points[second],
                new_points[first],
            )
            yield TSPSolution(new_points)

    def steepest_ascent(self):
        """Return the shortest tour among this one and all its neighbours.

        Despite the name, this minimises ``path_length``: the whole
        neighbourhood is scanned and the strictly shortest tour wins.
        When no neighbour is strictly shorter, ``self`` is returned, which
        callers can use to detect a local optimum.
        """
        best_score = self.path_length()
        best_path = self

        for candidate in self.neighborhood():
            candidate_length = candidate.path_length()
            if candidate_length < best_score:
                best_path = candidate
                best_score = candidate_length
        return best_path

class TSPProblem:
    """A travelling-salesman instance: a set of points to be visited.

    Points are stored in a ``set``, so duplicates in the input collapse to a
    single point and the points must be hashable.
    """

    def __init__(self, points):
        self.points = set(points)

    def random_solution(self):
        """Return a tour that visits all points in a uniformly random order."""
        # random.sample() rejects sets (deprecated in Python 3.9, TypeError
        # from 3.11 on), so the population is materialised as a list first.
        population = list(self.points)
        return TSPSolution(random.sample(population, len(population)))

    def random_greedy_solution(self):
        """Return a nearest-neighbour tour grown from a random start point.

        Starting at a randomly chosen point, the tour is extended by always
        moving to the closest point that has not been visited yet. An empty
        problem produces an empty tour.
        """
        if not self.points:
            # Nothing to choose a start point from; mirror random_solution().
            return TSPSolution([])

        start_point = random.choice(list(self.points))
        solution = [start_point]

        points_left = set(self.points)
        points_left.remove(start_point)

        while points_left:
            current = solution[-1]
            closest = min(
                points_left, key=lambda p: euclidean_distance(p, current)
            )
            solution.append(closest)
            points_left.remove(closest)

        return TSPSolution(solution)



# Problem instance: 50 points, each coordinate drawn with random.random().
points = [
    (random.random(), random.random())
    for _ in range(50)
]
tsp_problem = TSPProblem(points)

# Interactive mode so the figure can be refreshed while the search runs.
plt.ion()
# Two panels side by side: `ax` shows the current tour, `ax2` the tour
# length recorded at each iteration.
fig, (ax, ax2) = plt.subplots(nrows=1, ncols=2)
ax.set_title("Traveling Salesman Demo", fontsize=20)
# Hide ticks and tick labels on the tour panel.
ax.tick_params(
    which="both",
    bottom=False,
    labelbottom=False,
    left=False,
    labelleft=False,
)
# Leave room under the axes for the caption drawn below the tour panel.
plt.subplots_adjust(bottom=0.15)

# Empty line on the history panel; the main loop appends one sample per step.
data_line = ax2.plot([], [])[0]

# Handles to the artists drawn in the previous frame, plus the best tour
# length seen so far; all are filled in by the main loop.
plot = text = best = None

# The search starts from a uniformly random tour.
tsp_solution = tsp_problem.random_solution()

# Styling of the caption box; identical for every frame, so built once.
props = dict(boxstyle="round", facecolor="blue", alpha=0.5)

# Hill-climbing loop: draw the current tour, then move to its best neighbour.
# The loop ends when no neighbour is strictly shorter (local optimum).
while True:

    x_points = [p[0] for p in tsp_solution.points]
    y_points = [p[1] for p in tsp_solution.points]
    length = tsp_solution.path_length()

    # Append this iteration's tour length to the history curve on ax2.
    previous_line_data = data_line.get_data()
    xdata = list(previous_line_data[0])
    ydata = list(previous_line_data[1])
    xdata.append(len(xdata))
    ydata.append(length)
    data_line.set_data(xdata, ydata)
    ax2.set_xlim((0, len(xdata)))
    ax2.set_ylim((0, max(ydata)))

    # Replace the previous frame's tour with the current one; the first
    # point is repeated at the end so the drawn path is closed.
    if plot is not None:
        plot.remove()
    (plot,) = ax.plot(
        x_points + [x_points[0]],
        y_points + [y_points[0]],
        color="blue",
        marker="s",
        markerfacecolor="black",
        markersize=3,
    )

    # Refresh the caption with the best (shortest) length seen so far.
    if text is not None:
        text.remove()
    if best is None or best > length:
        best = length
    text = ax.text(
        0,
        -0.05,
        f"Best route length: {best}",
        transform=ax.transAxes,
        fontsize=14,
        verticalalignment="top",
        bbox=props,
    )
    # Give the GUI event loop a moment to redraw the figure.
    plt.pause(0.0001)

    next_solution = tsp_solution.steepest_ascent()
    if next_solution is tsp_solution:
        # steepest_ascent() hands back the same object when no neighbour is
        # strictly shorter; iterating further would only repeat this frame.
        break
    tsp_solution = next_solution

# Search has converged: leave interactive mode and keep the final figure on
# screen until the window is closed.
plt.ioff()
plt.show()



