import numpy as np
import math
import random
import time


def euclidean_distance(pt1, pt2):
    """
    Return the Euclidean distance between two points of any dimension.

    pt1, pt2: sequences (lists, tuples, NumPy arrays) of coordinates. They are
        paired coordinate-by-coordinate; if one is longer, its extra
        coordinates are ignored.
    """
    return math.sqrt(sum((a - b) ** 2 for a, b in zip(pt1, pt2)))


class ContinuousFunction:
    """
    The function being optimized, together with its rectangular search domain.

    eval_function is a callable taking one positional argument per dimension
    (e.g. ``lambda x, y: ...``) and returning the function's value there.
    bounds holds one ``(min, max)`` pair per dimension; the allowed region is
    the axis-aligned box those pairs describe.
    """

    def __init__(self, eval_function, bounds):
        self.eval_function = eval_function
        self.bounds = bounds

    def random(self):
        """Return a uniformly random point (as a list) within the bounds."""
        pt = []
        for (_min, _max) in self.bounds:
            diff = _max - _min
            # random.random() is in [0, 1), so each coordinate is in [_min, _max).
            pt.append(random.random() * diff + _min)
        return pt

    def in_bounds(self, point):
        """Return True if every coordinate of point lies within its (inclusive) bounds."""
        return all(lo <= p <= hi for p, (lo, hi) in zip(point, self.bounds))

    def score(self, point):
        """Evaluate the function at point (its coordinates are passed as separate arguments)."""
        return self.eval_function(*point)


class Firefly:
    """
    A single firefly that moves around the search space.

    At the moment it only stores its position. The position is a float NumPy
    array copied from the input. Forcing dtype=float means in-place updates such
    as ``firefly.position += step`` work even when the initial coordinates
    are integers; an int array would reject the float result.
    """

    def __init__(self, position):
        self.position = np.array(position, dtype=float)


class FireflyManager:
    """
    This is a class that handles the main routine, including managing
    all of the fireflies.

    Each generation, every firefly moves toward each brighter (better-scoring)
    in-bounds firefly, with a random Levy-style kick. Fireflies that leave
    the search space are replaced by random ones, and the best in-bounds
    solution seen so far is recorded.
    """

    def __init__(
        self,
        space,
        dimension,
        N,
        levy_parameter,
        alpha=1,
        beta=1,
        gamma=1,
        maximization=True,
    ):
        """
        space = object representing the problem (in this case a
            ContinuousFunction object); used via random(), in_bounds(point)
            and score(point)
        dimension = dimension of the problem (coordinates per position)
        N = number of fireflies
        levy_parameter = which Levy flight parameter to use -> bigger = less frequent large jumps
        alpha = scale of the random Levy step added on each move
        beta = attraction strength between fireflies at zero distance
        gamma = attraction decay rate: attraction ~ exp(-gamma * distance**2)
        maximization = True if maximizing, False if minimizing
        """
        self.space = space
        self.dimension = dimension
        self.N = N
        self.maximization = maximization

        self.levy_parameter = levy_parameter
        self.alpha = alpha
        self.beta = beta
        self.gamma = gamma

        self.fireflies = [self.random_firefly() for _ in range(N)]

        self.global_best_sol = None
        self.global_best_score = None
        self.set_global_best()

    def is_better(self, new_score, old_score):
        """
        Return True if new_score is strictly better than old_score (higher when
        maximizing, lower when minimizing).
        """
        if self.maximization:
            return new_score > old_score
        return new_score < old_score

    def set_global_best(self):
        """
        Score every firefly and remember the best position ever seen.

        The position is copied so later moves don't alter the stored best.
        """
        for firefly in self.fireflies:
            score = self.space.score(firefly.position)
            if self.global_best_score is None or self.is_better(
                score, self.global_best_score
            ):
                self.global_best_score = score
                self.global_best_sol = firefly.position.copy()

    def random_firefly(self):
        """Return a new firefly at a random point of the search space."""
        return Firefly(self.space.random())

    def levy(self):
        """
        Return one random Levy-style jump vector of length ``dimension``.

        Each component is ``±(u ** (-1 / levy_parameter) - 1)`` with ``u``
        uniform on (0, 1]. Its magnitude is heavy-tailed, with
        P(|step| > x) = (1 + x) ** -levy_parameter: usually small,
        occasionally very large.
        """
        exponent = -1 / self.levy_parameter
        return np.array(
            [
                # Use 1 - random() (in (0, 1]) rather than random() (in [0, 1)):
                # 0.0 ** negative raises ZeroDivisionError.
                random.choice([-1, 1]) * ((1.0 - random.random()) ** exponent - 1)
                for _ in range(self.dimension)
            ]
        )

    def advance(self):
        """
        Run one generation. Move all the fireflies, replace any that left the
        search space, shuffle the population, then check for a new best score.
        """
        for f1 in self.fireflies:
            # f1's score only changes when f1 itself moves, so evaluate it once
            # here and refresh it only after a move.
            f1_score = self.space.score(f1.position)
            for f2 in self.fireflies:
                if f1 is f2:
                    continue
                # Out-of-bounds fireflies don't attract others.
                if not self.space.in_bounds(f2.position):
                    continue

                f2_score = self.space.score(f2.position)
                if self.is_better(f2_score, f1_score):
                    self._move_toward(f1, f2)
                    f1_score = self.space.score(f1.position)

        self._replace_out_of_bounds()
        random.shuffle(self.fireflies)
        self.set_global_best()

    def _move_toward(self, f1, f2):
        """Move f1 toward f2 (distance-decayed attraction) plus a random Levy kick."""
        distance = euclidean_distance(f1.position, f2.position)
        attraction = (
            self.beta
            * math.exp(-self.gamma * distance ** 2)
            * (f2.position - f1.position)
        )
        levy = (
            self.alpha
            * np.random.random_sample((self.dimension,))
            * self.levy()
        )
        # Rebind rather than update in place, so this also works if a position
        # array happens to have an integer dtype.
        f1.position = f1.position + attraction + levy

    def _replace_out_of_bounds(self):
        """Drop every firefly outside the search space and add random ones to keep the count."""
        kept = [f for f in self.fireflies if self.space.in_bounds(f.position)]
        missing = len(self.fireflies) - len(kept)
        self.fireflies = kept + [self.random_firefly() for _ in range(missing)]


def eval_function(x, y):
    """
    Objective being optimized: the 2-D Michalewicz test function with
    steepness m = 10 (hence the ``** 20`` exponents).

    NOTE(review): the Michalewicz benchmark is usually *minimized*, but the
    manager below runs with maximization=True. Confirm which is intended.
    """
    return -(
        math.sin(x) * math.sin(x ** 2 / math.pi) ** 20
        + math.sin(y) * math.sin(2 * y ** 2 / math.pi) ** 20
    )


# Search domain: both coordinates in [0, 4]. Tuples are used so the repeated
# entry is not one shared mutable list.
bounds = [(0, 4)] * 2
cns_func = ContinuousFunction(eval_function, bounds)

# Only run the (endless) optimization when executed as a script, so the
# module can be imported without hanging.
if __name__ == "__main__":
    # 2-D problem, 5 fireflies, Levy parameter 2.
    FF = FireflyManager(
        cns_func, 2, 5, 2, alpha=0.05, beta=0.1, gamma=0.1, maximization=True
    )
    gen = 0
    # None (rather than 0) so the first generation's best is always reported.
    last_best = None

    # Runs until interrupted (e.g. Ctrl-C); reports whenever the best changes.
    while True:
        gen += 1
        FF.advance()
        if FF.global_best_score != last_best:
            print(f"Gen {gen}: best score = {FF.global_best_score} at point {FF.global_best_sol}")
            last_best = FF.global_best_score

