import matplotlib.pyplot as plt  # type: ignore


class TSPPlotter:
    """Live two-panel matplotlib view of a TSP search.

    The left axes shows the current tour; the right axes shows the history of
    values passed to ``add_to_score_line``, with a "Best Score" label beneath
    it that ``update_best`` keeps at the lowest value seen so far.

    The figure is created in interactive mode (``plt.ion()``) and shown
    without blocking. The update methods only modify artists; they do not call
    ``plt.pause`` or redraw the canvas themselves, so presumably the caller is
    responsible for refreshing the figure.
    """

    # Maximum number of score points retained (and drawn) on the score line;
    # the oldest points are discarded first.
    MAX_SCORE_POINTS = 10000

    def __init__(self):
        plt.ion()
        plt.style.use("ggplot")
        self.fig, (self.ax_tour, self.ax_score) = plt.subplots(1, 2)

        # Score history: score_x[i] is the running index of score_values[i].
        (self.score_line,) = self.ax_score.plot([], [])
        self.score_values = []
        self.score_x = []

        # Tour coordinates are not useful as tick labels, so hide them.
        self.ax_tour.tick_params(
            which="both", bottom=False, labelbottom=False, left=False, labelleft=False
        )
        self.tour_plot = None  # Line2D of the currently drawn tour, if any.

        self.best = None  # Lowest value passed to update_best so far.

        # Leave room below the axes for the "Best Score" label.
        plt.subplots_adjust(bottom=0.15)
        self.ax_tour.set_title("Tour", fontsize=20)
        self.ax_score.set_title("Value", fontsize=20)
        plt.show(block=False)

        self.props = dict(boxstyle="round", facecolor="blue", alpha=0.5)
        # Created once and updated in place by update_best, so its position
        # (just below the score axes) never changes.
        self.text = self.ax_score.text(
            0,
            -0.1,
            self._best_label(),
            transform=self.ax_score.transAxes,
            fontsize=14,
            verticalalignment="top",
            bbox=self.props,
        )

        # Let the GUI event loop render the initial figure.
        plt.pause(0.0001)

    def _best_label(self):
        """Return the text shown in the best-score box."""
        return f"Best Score: {self.best}"

    @staticmethod
    def _closed_tour(edges):
        """Walk ``edges`` from its first key and return the cities in visiting order.

        ``edges`` maps each city to the next city in the tour. The walk takes
        ``len(edges) - 1`` steps and the start city is appended again at the end,
        so the returned list describes a closed loop. An empty mapping yields
        ``[]``.

        Raises:
            KeyError: if the walk reaches a city that has no outgoing edge.
        """
        if not edges:
            return []
        start = next(iter(edges))
        points = [start]
        for _ in range(len(edges) - 1):
            points.append(edges[points[-1]])
        points.append(start)
        return points

    def show_tour(self, tour, update_plot=True):
        """Draw ``tour`` on the left axes, replacing any previously drawn tour.

        Args:
            tour: object whose ``edges`` mapping sends each city to the next
                city of the tour. Cities are indexed as ``p[0]``/``p[1]`` for
                plotting, so presumably they are (x, y) coordinate pairs.
            update_plot: when False, the tour is still walked (so a broken
                ``edges`` mapping still raises) but nothing is drawn.

        Raises:
            KeyError: if ``tour.edges`` does not form a walkable chain.
        """
        points = self._closed_tour(tour.edges)
        if not update_plot:
            return

        if self.tour_plot is not None:
            self.tour_plot.remove()
            self.tour_plot = None
        if not points:
            # Empty tour: nothing to draw, and the stale tour has been cleared.
            return

        # ``points`` already ends at its start city, so the loop is closed
        # without appending the first point again.
        x_points = [p[0] for p in points]
        y_points = [p[1] for p in points]
        (self.tour_plot,) = self.ax_tour.plot(
            x_points,
            y_points,
            color="blue",
            marker="s",
            markerfacecolor="black",
            markersize=3,
        )

    def add_to_score_line(self, value, update_plot=True):
        """Append ``value`` to the score history and optionally refresh the line.

        Only the most recent ``MAX_SCORE_POINTS`` values are kept. The x axis
        is widened when the newest point runs past its right edge; the y axis
        is reset to span ``[min(0, lowest), highest]`` whenever a value falls
        outside the current limits.

        Args:
            value: numeric score to record.
            update_plot: when False, only the stored history is updated.
        """
        self.score_x.append(self.score_x[-1] + 1 if self.score_x else 0)
        self.score_values.append(value)

        # Bound the history so the drawn line and the min/max scans stay bounded.
        excess = len(self.score_values) - self.MAX_SCORE_POINTS
        if excess > 0:
            del self.score_values[:excess]
            del self.score_x[:excess]

        if not update_plot:
            return

        self.score_line.set_data(self.score_x, self.score_values)
        _, x_high = self.ax_score.get_xlim()
        y_low, y_high = self.ax_score.get_ylim()

        if self.score_x[-1] > x_high:
            self.ax_score.set_xlim((self.score_x[0], self.score_x[-1]))

        # Scan the history once for each bound instead of twice.
        lowest = min(self.score_values)
        highest = max(self.score_values)
        if highest > y_high or lowest < y_low:
            self.ax_score.set_ylim((min(0, lowest), highest))

    def update_best(self, value):
        """Record ``value`` as the best score if it is lower than the current best.

        Lower is better. When the best improves, the label under the score
        axes is updated in place, keeping the position set in ``__init__``.
        """
        if self.best is None or value < self.best:
            self.best = value
            self.text.set_text(self._best_label())
