import matplotlib.pyplot as plt
import matplotlib.cm as cm

import numpy as np
import math
import random
import time


def euclidean_distance(pt1, pt2):
    """Return the straight-line distance between two 2-D points.

    Only the first two coordinates of each point are used.
    """
    dx = pt1[0] - pt2[0]
    dy = pt1[1] - pt2[1]
    return math.sqrt(dx ** 2 + dy ** 2)


class ContinuousFunction:
    """Objective function of several real variables restricted to a box.

    Args:
        eval_function: callable taking one positional argument per coordinate.
        bounds: sequence of ``(min, max)`` pairs, one per coordinate.
        tweak_delta: stored as-is; not used by the methods of this class.
    """

    def __init__(self, eval_function, bounds, tweak_delta):
        self.eval_function, self.bounds, self.tweak_delta = (
            eval_function,
            bounds,
            tweak_delta,
        )

    def random(self):
        """Return a point drawn uniformly from the box, as a list."""
        return [(hi - lo) * random.random() + lo for lo, hi in self.bounds]

    def in_bounds(self, point):
        """Return True if every coordinate lies in its closed ``[min, max]`` range."""
        for coord, (lo, hi) in zip(point, self.bounds):
            # Written as `not (a and b)` so NaN coordinates are rejected.
            if not (coord >= lo and coord <= hi):
                return False
        return True

    def score(self, point):
        """Evaluate the objective at *point*, passing coordinates positionally."""
        return self.eval_function(*point)


class Nest:
    """Holder for the single ``Egg`` (candidate solution) currently in a nest."""

    def __init__(self, egg):
        self.egg = egg

class Egg:
    """A candidate solution; ``position`` is a NumPy array copy of the input."""

    def __init__(self, position):
        coords = np.array(position)
        self.position = coords


class CuckooManager:
    """Cuckoo-search style optimizer over a continuous search space.

    ``space`` must provide ``random()`` (a new random point), ``in_bounds(point)``
    and ``score(point)``, as ``ContinuousFunction`` in this file does.

    Args:
        space: the search space / objective wrapper.
        dimension: length of each Levy step vector; presumably equal to the
            number of coordinates returned by ``space.random()`` -- TODO confirm.
        N: number of nests. ``advance_one_nest`` needs at least 2.
        levy_parameter: exponent ``beta`` used in the step
            ``u ** (-1 / beta) - 1``; smaller values give heavier tails.
        alpha: scale factor applied to every Levy step component.
        percent_to_fly: fraction of the worst nests moved each time
            ``drop_lowest_percentage`` is called.
        maximization: True to maximize ``space.score``, False to minimize it.
    """

    def __init__(
        self,
        space,
        dimension,
        N,
        levy_parameter,
        alpha,
        percent_to_fly,
        maximization=True,
    ):
        self.space = space
        self.dimension = dimension
        self.N = N
        self.maximization = maximization

        self.levy_parameter = levy_parameter
        self.alpha = alpha
        self.percent_to_fly = percent_to_fly

        self.nests = [self.random_nest() for _ in range(N)]

        self.global_best_sol = None
        self.global_best_score = None
        self.set_global_best()

    def is_better(self, new_score, old_score):
        """Return True if ``new_score`` strictly beats ``old_score`` for this objective."""
        if self.maximization:
            return new_score > old_score
        return new_score < old_score

    def set_global_best(self):
        """Fold the current nests into the best score/position seen so far."""
        for nest in self.nests:
            score = self.space.score(nest.egg.position)
            if self.global_best_score is None or self.is_better(
                score, self.global_best_score
            ):
                self.global_best_score = score
                # Copy so later in-place changes to the egg cannot alter the record.
                self.global_best_sol = nest.egg.position.copy()

    def random_nest(self):
        """Return a new nest holding an egg at a random point of the space."""
        return Nest(Egg(self.space.random()))

    def levy(self):
        """Return a random step vector of length ``self.dimension``.

        Each component is ``alpha * (+/-1) * (u ** (-1 / beta) - 1)``, with ``u``
        uniform on (0, 1]. ``u ** (-1 / beta)`` is Pareto-distributed, so
        occasional very long jumps occur.
        """
        exponent = -1 / self.levy_parameter
        # random.random() lies in [0, 1) and may return exactly 0.0, which would
        # raise ZeroDivisionError under a negative power; 1 - u lies in (0, 1].
        return np.array(
            [
                self.alpha
                * random.choice([-1, 1])
                * ((1.0 - random.random()) ** exponent - 1)
                for _ in range(self.dimension)
            ]
        )

    def advance_one_nest(self):
        """
        Pick a random nest and move its egg with a Levy flight to get solution S.
        Pick a different random nest; if S is in bounds and better than the egg
        in that nest, replace that egg with S.

        Raises ValueError (from ``random.sample``) when there are fewer than
        two nests.
        """
        nest1, nest2 = random.sample(self.nests, 2)

        candidate = nest1.egg.position + self.levy()

        # The bounds test is cheap, so do it first and only score in-bounds points.
        if self.space.in_bounds(candidate) and self.is_better(
            self.space.score(candidate),
            self.space.score(nest2.egg.position),
        ):
            nest2.egg = Egg(candidate)

    def drop_lowest_percentage(self):
        """Move the worst ``percent_to_fly`` fraction of nests by a Levy flight.

        ``self.nests`` is sorted in place, worst first: ascending scores when
        maximizing, descending when minimizing. Each of the first
        ``floor(percent_to_fly * len(nests))`` nests gets its egg replaced by
        ``position + levy()`` if that point is in bounds. Otherwise the nest is
        left unchanged.
        """
        # Worst-first order depends on the objective direction. An unconditional
        # ascending sort would move the *best* nests when minimizing.
        self.nests.sort(
            key=lambda nest: self.space.score(nest.egg.position),
            reverse=not self.maximization,
        )
        num_to_drop = math.floor(self.percent_to_fly * len(self.nests))

        for nest in self.nests[:num_to_drop]:
            new_pos = nest.egg.position + self.levy()
            if self.space.in_bounds(new_pos):
                nest.egg = Egg(new_pos)



class ContinuousOptimizationPlotter:
    """Live view of an optimizer working on a function of two variables.

    The left axes show a contour plot of the objective, with the current
    candidate points drawn as red markers. The right axes show the score
    history, with a "Best Score" label underneath.
    """

    # Cap on stored score samples; the oldest samples are discarded beyond this.
    MAX_HISTORY = 10000

    def __init__(self, contour_grids, levels=None, maximization=True):
        """Create the figure and draw the static contour plot.

        Args:
            contour_grids: ``(X, Y, Z)`` arrays passed to ``Axes.contour``.
            levels: optional contour levels; matplotlib picks them when None.
            maximization: True if larger values are better (used by
                ``update_best``).
        """
        plt.ion()
        plt.style.use("ggplot")
        self.fig, (self.ax_contour, self.ax_score) = plt.subplots(1, 2)

        (self.score_line,) = self.ax_score.plot([], [])
        self.score_values = []
        self.score_x = []

        (self.explore_line,) = self.ax_contour.plot(
            [], [], linestyle="", marker="o", color="red"
        )
        self.explored_points = []
        self.best = None

        plt.subplots_adjust(bottom=0.15)
        self.ax_contour.set_title("Contour Plot", fontsize=20)
        self.ax_score.set_title("Value", fontsize=20)
        plt.show(block=False)

        self.props = dict(boxstyle="round", facecolor="blue", alpha=0.5)
        # Created once here. update_best() only changes the string, so the
        # label stays at this position for the whole run.
        self.text = self.ax_score.text(
            0,
            -0.1,
            f"Best Score: {self.best}",
            transform=self.ax_score.transAxes,
            fontsize=14,
            verticalalignment="top",
            bbox=self.props,
        )

        X, Y, Z = contour_grids
        if levels is None:
            self.ax_contour.contour(X, Y, Z)
        else:
            self.ax_contour.contour(X, Y, Z, levels=levels)

        self.maximization = maximization

        plt.pause(0.0001)

    @staticmethod
    def points_to_xy(points):
        """Split a sequence of points into ``([x0, x1, ...], [y0, y1, ...])``."""
        x_points = [pt[0] for pt in points]
        y_points = [pt[1] for pt in points]
        return x_points, y_points

    def update_particles(self, particles, update_plot=True):
        """Redraw the candidate markers from objects exposing ``.position``."""
        if update_plot:
            self.explore_line.set_data(
                *self.points_to_xy([p.position for p in particles])
            )

    def add_to_score_line(self, value, update_plot=True):
        """Append ``value`` to the score history and optionally refresh the plot.

        The x coordinate is one more than the previous sample (0 for the first).
        At most ``MAX_HISTORY`` samples are kept. Axis limits only grow when
        the new data falls outside the current view.
        """
        self.score_values.append(value)
        self.score_x.append(self.score_x[-1] + 1 if self.score_x else 0)

        if len(self.score_values) > self.MAX_HISTORY:
            del self.score_values[0]
            del self.score_x[0]

        if not update_plot:
            return

        self.score_line.set_data(self.score_x, self.score_values)
        cur_x = self.ax_score.get_xlim()
        cur_y = self.ax_score.get_ylim()

        if self.score_x[-1] > cur_x[1]:
            self.ax_score.set_xlim((self.score_x[0], self.score_x[-1]))

        lowest = min(self.score_values)
        highest = max(self.score_values)
        if highest > cur_y[1] or lowest < cur_y[0]:
            # The lower limit always includes 0 so positive scores stay anchored.
            self.ax_score.set_ylim((min(0, lowest), highest))

    def update_best(self, value):
        """Record ``value`` as the best score if it improves on the current best."""
        if self.best is not None:
            improved = value > self.best if self.maximization else value < self.best
            if not improved:
                return
        self.best = value
        # Update the existing text artist in place so it keeps the position
        # chosen in __init__.
        self.text.set_text(f"Best Score: {self.best}")


# Demo: maximize f(x, y) = sin^2(x - y) * sin^2(x + y) / sqrt(x^2 + y^2)
# over the square [-2*pi, 2*pi]^2 while plotting the search live.

grid_delta = 0.01
tp = 2 * math.pi
levels = np.arange(0, 0.7, 0.03)
bounds = [[-tp, tp]] * 2
tweak_delta = 0.1


def eval_function(x, y):
    """Objective: sin^2(x - y) * sin^2(x + y) / sqrt(x^2 + y^2).

    The expression is 0/0 at the origin. Its limit there is 0 (the numerator
    is O(r**4)), so 0.0 is returned instead of raising ZeroDivisionError.
    """
    r = math.sqrt(x ** 2 + y ** 2)
    if r == 0.0:
        return 0.0
    return math.sin(x - y) ** 2 * math.sin(x + y) ** 2 / r


if __name__ == "__main__":
    x = np.arange(-tp, tp, grid_delta)
    y = np.arange(-tp, tp, grid_delta)
    X, Y = np.meshgrid(x, y)
    Z = np.sin(X - Y) ** 2 * np.sin(X + Y) ** 2 / np.sqrt(X ** 2 + Y ** 2)

    plotter = ContinuousOptimizationPlotter([X, Y, Z], levels, maximization=True)

    cns_func = ContinuousFunction(eval_function, bounds, tweak_delta)
    CM = CuckooManager(
        cns_func,
        dimension=2,
        N=10,
        levy_parameter=2,
        alpha=0.1,
        percent_to_fly=0.5,
        maximization=True,
    )
    # The plotted per-generation score must follow the objective direction.
    current_best = max if CM.maximization else min

    # Run until the figure window is closed (or the process is interrupted).
    while plt.fignum_exists(plotter.fig.number):
        current_scores = [cns_func.score(nest.egg.position) for nest in CM.nests]

        plotter.update_particles([nest.egg for nest in CM.nests])
        plotter.add_to_score_line(current_best(current_scores))
        CM.set_global_best()
        plotter.update_best(CM.global_best_score)
        plt.pause(0.0001)

        CM.advance_one_nest()
        CM.drop_lowest_percentage()
