# Tabu search for the Traveling Salesman Problem on 200 random cities.
# This version re-scores every candidate tour completely from scratch,
# which is very inefficient.

import math
import random
random.seed(1)
import itertools

from collections import defaultdict
from tqdm import tqdm

import matplotlib.pyplot as plt  # type: ignore


def euclidean_distance(pt1, pt2):
    """Return the straight-line distance between two points.

    Only the first two coordinates (x, y) of each point are used.
    """
    dx = pt1[0] - pt2[0]
    dy = pt1[1] - pt2[1]
    return math.sqrt(dx ** 2 + dy ** 2)

class TSP:
    """A Euclidean Traveling Salesman instance and its 2-opt neighbourhood.

    A tour is a list of cities (points) visited in order; the tour is closed,
    so the last city connects back to the first.
    """

    def __init__(self, points):
        # Stored as a set, so duplicate coordinates collapse into one city.
        self.points = set(points)

    def random_solution(self):
        """Return every city exactly once, in a uniformly random order."""
        # random.sample() no longer accepts a set (deprecated in 3.9, a
        # TypeError since 3.11), so hand it a sequence. list(set) yields the
        # same order older Pythons used internally, keeping seeded runs stable.
        cities = list(self.points)
        return random.sample(cities, len(cities))

    def random_greedy_solution(self):
        """Build a tour by nearest-neighbour construction.

        Start at a randomly chosen city, then repeatedly append the closest
        not-yet-visited city (relative to the last city added) until every
        city has been visited.
        """
        start_point = random.choice(list(self.points))
        solution = [start_point]

        points_left = set(self.points)
        points_left.remove(start_point)

        while points_left:
            closest = min(
                points_left, key=lambda p: euclidean_distance(p, solution[-1])
            )
            solution.append(closest)
            points_left.remove(closest)

        return solution

    def score(self, sol):
        """Return the total length of the closed tour ``sol``.

        Includes the edge from the last city back to the first. An empty tour
        has length 0.0 (previously this raised IndexError).
        """
        if not sol:
            return 0.0
        # Open path first, then the closing edge; same summation order as
        # before so scores (and hence the search trajectory) are unchanged.
        return sum(
            euclidean_distance(a, b) for a, b in zip(sol, sol[1:])
        ) + euclidean_distance(sol[-1], sol[0])

    def neighborhood(self, sol):
        """Yield every 2-opt neighbour of ``sol`` together with its move.

        For each index pair ``0 < i < j < len(sol)`` the segment ``sol[i..j]``
        (inclusive) is reversed, giving ``(len(sol) - 1) * (len(sol) - 2) / 2``
        neighbours; the city at index 0 always stays in place. Yields
        ``(new_tour, (i, j))`` where ``(i, j)`` identifies the move for the
        tabu list. ``sol`` itself is not modified.
        """
        for i, j in itertools.combinations(range(1, len(sol)), 2):
            yield sol[:i] + sol[i:j + 1][::-1] + sol[j + 1:], (i, j)

# Problem size: number of random cities placed in the unit square.
num_cities = 200
points = [(random.random(), random.random()) for _ in range(num_cities)]
tsp = TSP(points)

# Interactive figure: current tour on the left (ax), tour length per
# generation on the right (ax2).
plt.ion()
fig, (ax, ax2) = plt.subplots(1, 2)
(data_line,) = ax2.plot([], [])

plt.subplots_adjust(bottom=0.15)
ax.set_title("Traveling Salesman Demo", fontsize=20)
# Hide ticks and tick labels on the tour plot.
ax.tick_params(
    which="both", bottom=False, labelbottom=False, left=False, labelleft=False
)
plt.show(block=False)

# Artists that are removed and redrawn every generation.
plot = None
text = None
# Best tour seen so far and its length.
best_sol = None
best_score = None

# Tabu tenure: a move applied in generation g is skipped while the current
# generation is below g + taboo_time.
taboo_time = 100
sol = tsp.random_solution()
# Maps a 2-opt move (i, j) to the generation it stays taboo until; 0 = never.
taboo = defaultdict(int)
generation = 0

# Text-box style for the "best route" label; constant, so built once.
props = dict(boxstyle="round", facecolor="blue", alpha=0.5)

# Run until the user closes the plot window or no admissible move remains.
# (Previously `while True` never ended, so the final plt.show was unreachable
# and the search kept running after the window was closed.)
while plt.fignum_exists(fig.number):
    generation += 1

    x_points = [p[0] for p in sol]
    y_points = [p[1] for p in sol]
    length = tsp.score(sol)

    # Append this generation's tour length to the convergence plot.
    previous_line_data = data_line.get_data()
    xdata = list(previous_line_data[0])
    ydata = list(previous_line_data[1])
    xdata.append(len(xdata))
    ydata.append(length)
    data_line.set_data(xdata, ydata)
    ax2.set_xlim((0, len(xdata)))
    ax2.set_ylim((0, max(ydata)))

    # Redraw the current tour, closing the loop back to the first city.
    if plot is not None:
        plot.remove()
    (plot,) = ax.plot(
        x_points + [x_points[0]],
        y_points + [y_points[0]],
        color="blue",
        marker="s",
        markerfacecolor="black",
        markersize=3,
    )
    if text is not None:
        text.remove()
    if best_score is None or best_score > length:
        best_score = length
        best_sol = sol
    text = ax.text(
        0,
        -0.05,
        f"Best route length: {best_score}",
        transform=ax.transAxes,
        fontsize=14,
        verticalalignment="top",
        bbox=props,
    )
    plt.pause(0.0001)

    # Steepest step: best non-taboo 2-opt neighbour, even if it is worse than
    # the current tour (this is what lets tabu search escape local minima).
    steepest_best_sol = None
    steepest_best_score = None
    steepest_move = None

    for s, move in tsp.neighborhood(sol):
        if taboo[move] > generation:
            continue
        pl = tsp.score(s)
        if steepest_best_score is None or pl < steepest_best_score:
            steepest_best_sol = s
            steepest_best_score = pl
            steepest_move = move

    if steepest_best_sol is None:
        # Every move is taboo (possible when the neighbourhood is small
        # relative to taboo_time); continuing would set sol to None and crash.
        print("No admissible move left; stopping.")
        break

    sol = steepest_best_sol
    print(steepest_best_score)
    taboo[steepest_move] = generation + taboo_time

# Keep the final state on screen (returns immediately if the window is closed).
plt.show(block=True)