"""
    Weighted Interval Scheduling - live coding
    1. N meeting requests, each with a start time and end time and a value
    2. Goal is to accept the subset of requests with largest total value
        but no conflicts
"""

from itertools import combinations
import math
import random


class Meeting:
    """
    A single meeting request: an interval from ``start`` to ``end`` with a ``value``.

    Two meetings conflict when their intervals overlap; meetings that merely
    touch (one ends exactly when the other starts) are compatible.
    Equality and hashing are based on (start, end, value), so meetings can be
    stored in sets.
    """

    def __init__(self, start, end, value):
        self.start = start
        self.end = end
        self.value = value

    def __repr__(self):
        return f"Meeting({self.start}, {self.end}, {self.value})"

    def __str__(self):
        # e.g. ((3, 5), 2)
        return f"(({self.start}, {self.end}), {self.value})"

    def __eq__(self, other):
        # For non-Meeting operands, let Python fall back to its default
        # handling instead of raising AttributeError on .start/.end/.value.
        if not isinstance(other, Meeting):
            return NotImplemented
        return (
            self.start == other.start
            and self.end == other.end
            and self.value == other.value
        )

    def __hash__(self):
        # Must agree with __eq__: equal meetings hash equally.
        return hash((self.start, self.end, self.value))

    def compatible(self, other):
        """Return True if this meeting and ``other`` do not overlap."""
        return self.end <= other.start or self.start >= other.end


class Solution:
    """
    A candidate answer: an unordered collection of meetings.

    Holding a Solution says nothing about feasibility -- the meetings may
    overlap. Use ``is_valid`` to check that they are pairwise compatible.
    """

    def __init__(self, meetings):
        # Always take a private set copy of whatever iterable is passed in.
        self.meetings = set(meetings)

    def score(self):
        """Return the summed value of every meeting held."""
        total = sum([meeting.value for meeting in self.meetings])
        return total

    def is_valid(self):
        """Return True when no two held meetings are incompatible."""
        for first, second in combinations(self.meetings, 2):
            if not first.compatible(second):
                return False
        return True

    def tweak(self, all_meetings):
        """
        Return a new, slightly different Solution; ``self`` is left untouched.

        A fair coin picks the move:
          * toggle -- choose a random meeting from ``all_meetings`` and remove
            it if held, otherwise add it;
          * swap -- drop a random held meeting (if any), then add a random
            meeting from ``all_meetings`` that this solution did not hold.
        The neighbour returned is not guaranteed to be valid.
        """
        assert isinstance(all_meetings, list)

        neighbour = Solution(set(self.meetings))
        coin = random.randint(0, 1)

        if coin != 0:
            # Swap move.
            if neighbour.meetings:
                victim = random.choice(list(neighbour.meetings))
                neighbour.meetings.remove(victim)
            unused = list(set(all_meetings) - set(self.meetings))
            if unused:
                neighbour.meetings.add(random.choice(unused))
            return neighbour

        # Toggle move.
        picked = random.choice(all_meetings)
        if picked in self.meetings:
            neighbour.meetings.remove(picked)
        else:
            neighbour.meetings.add(picked)
        return neighbour


def random_request(horizon=100, min_duration=5, max_value=10):
    """
    Generate one random meeting request.

    ``start`` and ``end`` are two distinct integers drawn from
    ``range(horizon)`` (so ``start < end``), and ``value`` is a uniform real
    number in ``[0, max_value)``. Candidates lasting less than
    ``min_duration`` are rejected and redrawn.

    Args:
        horizon: exclusive upper bound for start/end times (default 100).
        min_duration: minimum accepted ``end - start`` (default 5).
        max_value: scale of the uniform value (default 10).

    Returns:
        A Meeting.

    Raises:
        ValueError: if ``min_duration`` cannot be reached within ``horizon``
            (the rejection loop would otherwise never terminate), or if
            ``horizon`` is too small to draw two distinct times.
    """
    # The longest possible interval is 0 .. horizon - 1.
    if min_duration > horizon - 1:
        raise ValueError(
            f"min_duration={min_duration} is unreachable with horizon={horizon}"
        )
    while True:
        start, end = sorted(random.sample(range(horizon), 2))
        # A full candidate (interval + value) is drawn on every attempt,
        # including rejected ones.
        value = random.random() * max_value
        if end - start >= min_duration:
            return Meeting(start, end, value)


def make_requests(n):
    """Return a list of ``n`` independently generated random meeting requests."""
    requests = [random_request() for _ in range(n)]
    return requests


def random_valid_solution(all_meetings):
    """
    Build a random conflict-free Solution by greedy insertion in shuffled order.

    ``all_meetings`` is copied and shuffled (the input is not modified). Each
    meeting is kept if it is compatible with every meeting kept so far.

    Returns:
        A valid Solution.
    """
    shuffled = list(all_meetings)
    random.shuffle(shuffled)
    kept = []
    for meeting in shuffled:
        # The kept meetings are already pairwise compatible, so only the new
        # meeting needs checking: O(k) per candidate instead of copying the
        # set and re-validating all O(k^2) pairs.
        if all(meeting.compatible(other) for other in kept):
            kept.append(meeting)
    sol = Solution(kept)
    # One-off sanity check of the greedy construction.
    assert sol.is_valid()
    return sol


def hill_climbing(all_meetings, max_failures=1_000):
    """
    Improve a random valid solution by accepting non-worsening tweaks.

    Starting from ``random_valid_solution(all_meetings)``, repeatedly draw a
    valid neighbour via ``Solution.tweak``. A neighbour whose score is at
    least the current score replaces the current solution and resets the
    failure counter. A strictly worse neighbour counts as one failure.

    Args:
        all_meetings: list of all candidate meetings.
        max_failures: stop after this many consecutive worse neighbours
            (default 1_000).

    Returns:
        The final (valid) Solution.

    Note: equal-score (sideways) moves also reset the counter, so a plateau
    of equal scores can prolong the search.
    """
    sol = random_valid_solution(all_meetings)
    sol_score = sol.score()

    improvement_failures = 0
    while improvement_failures < max_failures:
        new_sol = sol.tweak(all_meetings)
        # Resample from the current solution until the neighbour is valid.
        while not new_sol.is_valid():
            new_sol = sol.tweak(all_meetings)
        new_score = new_sol.score()
        if new_score >= sol_score:
            improvement_failures = 0
            sol, sol_score = new_sol, new_score
        else:
            improvement_failures += 1

    return sol


def hill_climbing_with_random_restarts(all_meetings, restarts=None):
    """
    Run ``hill_climbing`` repeatedly from fresh random starts, printing each result.

    Args:
        all_meetings: list of all candidate meetings.
        restarts: number of hill-climbing runs. ``None`` (the default) runs
            forever, until interrupted.

    Returns:
        The highest-scoring Solution found (``None`` if ``restarts`` is 0).
        Only reached when ``restarts`` is finite.
    """
    best_sol = None
    best_score = 0
    attempt = 0
    while restarts is None or attempt < restarts:
        attempt += 1
        hc = hill_climbing(all_meetings)
        score = hc.score()
        if best_sol is None or score > best_score:
            best_sol, best_score = hc, score
        print(
            f"Found a solution with score {score} with {len(hc.meetings)} meetings. Best ever: {best_score}."
        )
    return best_sol


def simulated_annealing(
    all_meetings,
    initial_temperature=10,
    alpha=0.98,
    steps_per_temperature=1_000,
    final_temperature=None,
):
    """
    Search for a high-value valid solution with simulated annealing.

    At each temperature, ``steps_per_temperature`` valid neighbours are drawn
    with ``Solution.tweak``. Strictly better neighbours are always accepted.
    Others (``delta <= 0``) are accepted with probability
    ``exp(delta / temperature)``. After each batch the temperature is
    multiplied by ``alpha`` and a progress line is printed.

    Args:
        all_meetings: list of all candidate meetings.
        initial_temperature: starting temperature (default 10).
        alpha: geometric cooling factor, must be in (0, 1) (default 0.98).
        steps_per_temperature: neighbours tried per temperature (default 1_000).
        final_temperature: stop once the temperature drops to this value.
            Defaults to ``initial_temperature / 500``.

    Returns:
        The best Solution seen (at least as good as the starting solution).

    Raises:
        ValueError: if ``alpha`` is not in (0, 1), since alpha >= 1 never cools.
    """
    if not 0 < alpha < 1:
        raise ValueError(f"alpha must be in (0, 1), got {alpha}")
    if final_temperature is None:
        final_temperature = initial_temperature / 500

    temperature = initial_temperature

    # start with a random valid solution
    sol = random_valid_solution(all_meetings)
    sol_score = sol.score()
    # Seed the incumbent with the start: if no tweak ever strictly improves,
    # best_sol must still be a real Solution rather than None.
    best_sol, best_score = sol, sol_score

    while temperature > final_temperature:
        worse_solutions_found = 0
        worse_solutions_accepted = 0

        for _ in range(steps_per_temperature):
            new_sol = sol.tweak(all_meetings)
            while not new_sol.is_valid():
                new_sol = sol.tweak(all_meetings)
            new_score = new_sol.score()
            delta = new_score - sol_score
            if delta > 0:
                sol, sol_score = new_sol, new_score
                if sol_score > best_score:
                    best_sol, best_score = sol, sol_score
            else:
                # this is a worse (or equal) solution
                worse_solutions_found += 1
                acceptance_probability = math.exp(delta / temperature)
                if random.random() < acceptance_probability:
                    worse_solutions_accepted += 1
                    sol, sol_score = new_sol, new_score
        temperature *= alpha
        # Guard against a batch in which every neighbour was an improvement.
        acceptance_rate = (
            worse_solutions_accepted / worse_solutions_found * 100
            if worse_solutions_found
            else 0.0
        )
        print(
            f"T: {temperature}, current score: {sol_score}, best score: {best_score}, worse accepted = {worse_solutions_accepted}/{worse_solutions_found} = {acceptance_rate:.2f}%."
        )
    return best_sol


if __name__ == "__main__":
    # Run the demos only when executed as a script, so importing this module
    # does not start long-running searches.
    random_data = make_requests(100)

    simulated_annealing(random_data)
    # Keeps restarting until interrupted (e.g. Ctrl-C).
    hill_climbing_with_random_restarts(random_data)
