# optimal score ~ 0.67366

import matplotlib.pyplot as plt
import matplotlib.cm as cm

import numpy as np
import math
import random
import time


class ContinuousFunction:
    """Objective function of several real variables restricted to a box.

    Holds the objective ``eval_function`` (called with the coordinates as
    positional arguments), per-axis ``(min, max)`` bounds, and
    ``tweak_delta``, the half-width of the uniform step used by ``tweak``.
    """

    def __init__(self, eval_function, bounds, tweak_delta):
        self.eval_function = eval_function
        self.bounds = bounds
        self.tweak_delta = tweak_delta

    def random(self):
        """Return a point sampled uniformly from the bounding box."""
        return [lo + (hi - lo) * random.random() for lo, hi in self.bounds]

    def in_bounds(self, point):
        """Return True if every coordinate lies in its closed ``[min, max]`` range."""
        for coord, (lo, hi) in zip(point, self.bounds):
            if not lo <= coord <= hi:
                return False
        return True

    def tweak(self, point):
        """Return a random neighbour of ``point`` that lies inside the box.

        Every coordinate is offset independently by a uniform amount in
        ``[-tweak_delta, tweak_delta)``. Candidates that leave the box are
        discarded and redrawn.

        NOTE: this never terminates if no candidate can land inside the box,
        e.g. when ``point`` lies more than ``tweak_delta`` outside it.
        """
        step = self.tweak_delta
        while True:
            candidate = [c + step * (2 * random.random() - 1) for c in point]
            if self.in_bounds(candidate):
                return candidate

    def score(self, point):
        """Evaluate the objective at ``point``."""
        return self.eval_function(*point)


class GeometricSimulatedAnnealing:
    """Simulated-annealing acceptance rule with a geometric cooling schedule.

    The starting temperature is estimated from random worsening moves (see
    ``determine_initial_temp``). Each call to ``advance_temp`` multiplies the
    temperature by ``alpha``.

    Args:
        space: object providing ``random()``, ``tweak(point)`` and
            ``score(point)``, such as ``ContinuousFunction``.
        alpha: multiplicative cooling factor used by ``advance_temp``
            (presumably in (0, 1) so the temperature decreases).
        maximization: if True higher scores are better, otherwise lower.
    """

    def __init__(self, space, alpha, maximization=True):
        self.space = space
        self.maximization = maximization
        self.alpha = alpha
        # determine_initial_temp reads self.space and self.maximization.
        self.initial_temp = self.determine_initial_temp()
        self.temp = self.initial_temp

    def _improvement(self, old_score, new_score):
        """Return the score change, signed so that positive means "better"."""
        change = new_score - old_score
        return change if self.maximization else -change

    def determine_initial_temp(self, trials=1000, initial_prob=0.95, max_attempts=None):
        """Estimate a starting temperature from random worsening moves.

        Draws random points and one tweak of each until ``trials`` worsening
        moves have been seen. It then returns ``avg_worsening / ln(initial_prob)``,
        which is the temperature at which a move of average worsening
        ``avg_worsening`` (< 0) is accepted with probability ``initial_prob``,
        since ``exp(avg_worsening / T0) == initial_prob``.

        Args:
            trials: number of worsening moves to average over.
            initial_prob: target acceptance probability for an average
                worsening at the initial temperature.
            max_attempts: maximum number of sampled moves. Defaults to
                ``100 * trials``.

        Returns:
            The initial temperature (positive when ``0 < initial_prob < 1``).

        Raises:
            ValueError: if fewer than ``trials`` worsening moves are found
                within ``max_attempts`` samples (e.g. a constant objective),
                instead of looping forever.
        """
        if max_attempts is None:
            max_attempts = 100 * trials

        total_change = 0
        decreases = 0
        attempts = 0
        while decreases < trials:
            if attempts >= max_attempts:
                raise ValueError(
                    f"found only {decreases} of {trials} worsening moves in "
                    f"{max_attempts} samples; cannot estimate an initial temperature"
                )
            attempts += 1
            sol = self.space.random()
            new_sol = self.space.tweak(sol)
            # Score the new point first, matching the original evaluation order.
            new_score = self.space.score(new_sol)
            old_score = self.space.score(sol)
            diff = self._improvement(old_score, new_score)
            if diff < 0:
                decreases += 1
                total_change += diff
        avg_change = total_change / trials
        return avg_change / math.log(initial_prob)

    def advance_temp(self):
        """Cool geometrically: ``temp <- temp * alpha``."""
        self.temp *= self.alpha

    def accept(self, old_score, new_score):
        """Metropolis test for a move from ``old_score`` to ``new_score``.

        Strict improvements are always accepted. Any other move is accepted
        with probability ``exp(change / temp)``, where ``change <= 0`` is the
        signed improvement.
        """
        score_change = self._improvement(old_score, new_score)
        if score_change > 0:
            return True
        return random.random() < math.exp(score_change / self.temp)

    def is_worse(self, old_score, new_score):
        """Return True if ``new_score`` is strictly worse than ``old_score``."""
        return self._improvement(old_score, new_score) < 0


class ContinuousOptimizationPlotter:
    """Live 2x2 matplotlib dashboard for optimizing a function of two variables.

    Panels, row-major:
    - a contour plot of the objective with the explored path drawn over it;
    - the current objective value per step;
    - the temperature per temperature level;
    - the fraction of worsening moves that were accepted.

    Points are treated as ``(x, y)`` pairs, so only functions of two
    variables (3D surfaces) are supported.
    """

    # Style shared by every explored-path segment on the contour axes.
    _EXPLORE_STYLE = dict(linestyle='-', marker='.', color='red')
    # Beyond these lengths the oldest entry is dropped on each append.
    _MAX_EXPLORED_POINTS = 3000
    _MAX_SCORE_VALUES = 10000
    # Vertical position, in value-panel axes coordinates, of the
    # "Best Score" box. Used for both the initial and the updated label so
    # the box does not jump.
    _BEST_TEXT_Y = -0.05

    def __init__(self, contour_grids, levels=None, maximization=True):
        """Create the interactive figure and draw the static contour plot.

        Args:
            contour_grids: ``(X, Y, Z)`` grids passed to ``Axes.contour``.
            levels: optional contour levels; matplotlib chooses them if None.
            maximization: if True ``update_best`` keeps the largest value
                seen, otherwise the smallest.
        """
        self.maximization = maximization

        plt.ion()
        plt.style.use('ggplot')
        self.fig, ((self.ax_contour, self.ax_score), (self.ax_temp, self.ax_accept)) = plt.subplots(2, 2)

        self.score_line, = self.ax_score.plot([], [])
        self.score_values = []
        self.score_x = []

        self.explore_line, = self.ax_contour.plot([], [], **self._EXPLORE_STYLE)
        self.explored_points = []
        self.best = None

        self.temp_line, = self.ax_temp.plot([], [], linestyle='-', marker='.', color='blue')
        self.temps = []

        self.accept_line, = self.ax_accept.plot([], [], linestyle='-', marker='.', color='green')
        self.accepts = []
        self.ax_accept.set_ylim((0, 1))

        plt.subplots_adjust(bottom=0.15)
        self.ax_contour.set_title("Contour Plot", fontsize=20)
        self.ax_score.set_title("Value", fontsize=20)
        self.ax_temp.set_title("Temperature", fontsize=20)
        self.ax_accept.set_title("% Worsenings Accepted", fontsize=20)
        plt.show(block=False)

        self.props = dict(boxstyle="round", facecolor="blue", alpha=0.5)
        self.text = self.ax_score.text(
            0,
            self._BEST_TEXT_Y,
            f"Best Score: {self.best}",
            transform=self.ax_score.transAxes,
            fontsize=14,
            verticalalignment="top",
            bbox=self.props,
        )

        X, Y, Z = contour_grids
        if levels is None:
            self.ax_contour.contour(X, Y, Z)
        else:
            self.ax_contour.contour(X, Y, Z, levels=levels)

        plt.pause(0.0001)

    @staticmethod
    def points_to_xy(points):
        """Split a sequence of ``(x, y, ...)`` points into x and y lists."""
        x_points = [pt[0] for pt in points]
        y_points = [pt[1] for pt in points]
        return x_points, y_points

    def add_explore_point(self, point, new_line=False, update_plot=True):
        """Append ``point`` to the explored path on the contour plot.

        Args:
            point: an ``(x, y, ...)`` point; only the first two coordinates
                are drawn.
            new_line: if True, finish the current path segment and start a
                new one beginning at ``point``.
            update_plot: if True, push the buffered points to the line
                artist now.
        """
        if new_line:
            # Flush any buffered points that were not yet drawn (update_plot may
            # have been False) onto the finished segment, then start a new one.
            self.explore_line.set_data(*self.points_to_xy(self.explored_points))
            self.explore_line, = self.ax_contour.plot([], [], **self._EXPLORE_STYLE)
            self.explored_points = []

        self.explored_points.append(point)
        if len(self.explored_points) > self._MAX_EXPLORED_POINTS:
            self.explored_points.pop(0)
        if update_plot:
            self.explore_line.set_data(*self.points_to_xy(self.explored_points))

    def add_to_score_line(self, value, update_plot=True):
        """Append ``value`` to the value-per-step plot.

        When ``update_plot`` is True, the axes limits are grown to keep the
        latest data visible.
        """
        self.score_values.append(value)
        self.score_x.append(self.score_x[-1] + 1 if self.score_x else 0)

        if len(self.score_values) > self._MAX_SCORE_VALUES:
            self.score_values.pop(0)
            self.score_x.pop(0)

        if update_plot:
            self.score_line.set_data(self.score_x, self.score_values)
            cur_x = self.ax_score.get_xlim()
            cur_y = self.ax_score.get_ylim()

            if self.score_x[-1] > cur_x[1]:
                self.ax_score.set_xlim((self.score_x[0], self.score_x[-1]))
            hi = max(self.score_values)
            lo = min(self.score_values)
            if hi > cur_y[1] or lo < cur_y[0]:
                # The lower limit never rises above 0, so the zero line stays in view.
                self.ax_score.set_ylim((min(0, lo), hi))

    def add_to_temp_line(self, temp):
        """Append ``temp`` to the temperature plot and grow its limits if needed."""
        self.temps.append(temp)
        self.temp_line.set_data(range(len(self.temps)), self.temps)

        cur_x = self.ax_temp.get_xlim()
        cur_y = self.ax_temp.get_ylim()

        if len(self.temps) > cur_x[1]:
            self.ax_temp.set_xlim((0, cur_x[1] + 10))
        if max(self.temps) > cur_y[1]:
            self.ax_temp.set_ylim((0, max(self.temps)))

    def add_to_accept_line(self, accept):
        """Append an acceptance fraction to its plot and widen the x-limit if needed."""
        self.accepts.append(accept)
        self.accept_line.set_data(range(len(self.accepts)), self.accepts)

        cur_x = self.ax_accept.get_xlim()
        if len(self.accepts) > cur_x[1]:
            self.ax_accept.set_xlim((0, cur_x[1] + 10))

    def update_best(self, value):
        """Record ``value`` as the best score if it improves on the current best.

        "Improves" means larger when ``maximization`` is True, smaller
        otherwise. The on-screen label is refreshed whenever the best changes.
        """
        improved = (
            self.best is None
            or (self.maximization and value > self.best)
            or ((not self.maximization) and value < self.best)
        )
        if improved:
            self.best = value
            self.text.set_text(f"Best Score: {self.best}")




# Objective: f(x, y) = sin(x - y)^2 * sin(x + y)^2 / sqrt(x^2 + y^2),
# maximized over the square [-2*pi, 2*pi] x [-2*pi, 2*pi].

# Grid used only for the background contour plot.
grid_delta = 0.01
tp = 2 * math.pi
x = np.arange(-tp, tp, grid_delta)
y = np.arange(-tp, tp, grid_delta)
X, Y = np.meshgrid(x, y)
Z = np.sin(X-Y)**2 * np.sin(X+Y)**2 / np.sqrt(X**2 + Y**2)
levels = np.arange(0, 0.7, 0.03)

plotter = ContinuousOptimizationPlotter([X, Y, Z], levels, maximization=True)


def eval_function(x, y):
    """Scalar version of the objective sampled into ``Z`` above."""
    return math.sin(x-y)**2 * math.sin(x+y)**2 / math.sqrt(x**2 + y**2)


bounds = [[-tp, tp]]*2
tweak_delta = 0.1
cns_func = ContinuousFunction(eval_function, bounds, tweak_delta)

SA = GeometricSimulatedAnnealing(cns_func, 0.98, maximization=True)

steps_per_temp = 1000     # proposals evaluated at each temperature level
plot_every = 1000         # redraw interval in proposals (1000 -> once per level)
final_temp_ratio = 0.005  # stop once temp drops below this fraction of initial_temp

point = cns_func.random()
value = cns_func.score(point)


while SA.temp > final_temp_ratio * SA.initial_temp:

    plotter.add_to_temp_line(SA.temp)

    total = 0   # worsening moves proposed at this temperature
    accept = 0  # worsening moves accepted at this temperature

    for i in range(steps_per_temp):
        update_plot = (i % plot_every == 0)

        plotter.add_explore_point(point, update_plot=update_plot)
        plotter.update_best(value)
        plotter.add_to_score_line(value, update_plot=update_plot)

        if update_plot:
            plt.pause(0.00001)

        new_point = cns_func.tweak(point)
        new_value = cns_func.score(new_point)

        worse = SA.is_worse(value, new_value)
        if worse:
            total += 1
        if SA.accept(value, new_value):
            if worse:
                accept += 1
            point = new_point
            value = new_value

    # The ratio is undefined if no worsening move was proposed at this level.
    # Plot NaN (drawn as a gap) rather than dividing by zero.
    plotter.add_to_accept_line(accept / total if total else math.nan)
    SA.advance_temp()

plt.show(block=True)