import random
import string
import time


############################################################
# knapsack problem specs
############################################################


def make_toy_knapsack():
    """Return (names, values, weights, capacity) for a tiny three-item problem.

    Each item's value is its per-unit density multiplied by its weight; the
    knapsack capacity is 10.
    """
    catalog = [
        # (name, density, weight)
        ("gold", 50, 5),
        ("saffron", 60, 4),
        ("diamond", 500, 3),
    ]
    names = [name for name, _, _ in catalog]
    values = [density * weight for _, density, weight in catalog]
    weights = [weight for _, _, weight in catalog]
    return names, values, weights, 10


def make_names_list(num_items):
    """Return the first `num_items` short lowercase names.

    Names run through the single letters "a".."z" and then the two-letter
    combinations "aa", "ab", ..., "zz" (702 names in total); requesting more
    than that simply returns all of them.
    """
    letters = string.ascii_lowercase
    singles = list(letters)
    doubles = [first + second for first in letters for second in letters]
    return (singles + doubles)[:num_items]


def make_random_knapsack(
    seed=5,
    num_items=10, weight_range=(1, 5), capacity=20,
    float_weights=False,
):
    """Return (names, values, weights, capacity) for a reproducible random problem.

    Reseeds the global `random` generator with `seed`, then draws one integer
    density in [10, 100] per item, followed by one weight per item between
    weight_range[0] and weight_range[1] (random.uniform floats when
    `float_weights` is true, random.randint integers otherwise). Each item's
    value is its density times its weight.
    """
    random.seed(seed)
    names = make_names_list(num_items)

    # all densities are drawn before any weight, so the sequence is fixed per seed
    densities = [random.randint(10, 100) for _ in range(num_items)]

    if float_weights:
        draw = random.uniform
    else:
        draw = random.randint
    weights = [draw(weight_range[0], weight_range[1]) for _ in range(num_items)]

    values = [density * weight for density, weight in zip(densities, weights)]
    return names, values, weights, capacity


def group_item_info(names, values, weights):
    """Combine parallel name/value/weight lists into (name, value, weight) tuples.

    The number of tuples is taken from `names`.
    """
    return [(name, values[i], weights[i]) for i, name in enumerate(names)]


def test_make_knapsack():
    """Print the toy and the default random problem, raw and grouped per item."""
    problems = (make_toy_knapsack(), make_random_knapsack())
    for problem in problems:
        print()
        print(problem)
        names, values, weights, _capacity = problem
        grouped = group_item_info(names, values, weights)
        print(grouped)


# print()
# test_make_knapsack()


def item_names(items):
    """Return the name (first field) of each (name, value, weight) item."""
    return [entry[0] for entry in items]


def total_value(items):
    """Return the sum of the value field (index 1) over all items."""
    return sum(entry[1] for entry in items)


def total_weight(items):
    """Return the sum of the weight field (index 2) over all items."""
    return sum(entry[2] for entry in items)


############################################################
# exhaustive enumeration for optimal 0/1 knapsack solution
############################################################


def powerset(items):
    """Return every subset of `items` as a list of sets.

    The list has 2**len(items) entries: first all subsets that omit items[0],
    then the same subsets (in the same order) with items[0] added.

    Args:
        items: a sliceable sequence (list, tuple, range, ...) of hashable elements.

    Returns:
        list of sets.
    """
    # base case: the empty sequence has exactly one subset
    if len(items) == 0:
        return [set()]

    # recursive case: compute the subsets of the rest ONCE and reuse them.
    # Recursing twice here would recompute the identical list and make the
    # whole enumeration O(n * 2**n) instead of O(2**n).
    first, rest = items[0], items[1:]
    combos_without = powerset(rest)
    # `{first} | combo` builds a new set, so combos_without is left untouched
    combos_with = [{first} | combo for combo in combos_without]
    return combos_without + combos_with


# print(powerset(range(3)))


def knapsack_enumerate(items, capacity):
    """Solve 0/1 knapsack by brute force over every subset of `items`.

    Args:
        items: sequence of (name, value, weight) tuples.
        capacity: maximum total weight a selection may have.

    Returns:
        (names, value): names of the best feasible selection and its total
        value. When no selection has positive value the result is ([], 0).
    """
    best_candidate = []  # empty selection, not None, when nothing beats value 0
    best_value = 0
    for selection in powerset(items):
        # feasible means not exceeding the capacity; an exact fit is allowed
        if total_weight(selection) <= capacity:
            value = total_value(selection)
            if value > best_value:
                best_candidate = item_names(selection)
                best_value = value
    return best_candidate, best_value


def knapsack_enumerate_comprehension(items, capacity):
    """Brute-force 0/1 knapsack written as a single max() over feasible subsets.

    Args:
        items: sequence of (name, value, weight) tuples.
        capacity: maximum total weight a selection may have.

    Returns:
        (names, value) for the feasible selection with the largest total
        value; ([], 0) if no selection is feasible.
    """
    # feasible means not exceeding the capacity; an exact fit is allowed
    feasible = (
        (item_names(selection), total_value(selection))
        for selection in powerset(items)
        if total_weight(selection) <= capacity
    )
    # `default` avoids ValueError when nothing is feasible (negative capacity)
    return max(feasible, key=lambda candidate: candidate[1], default=([], 0))


def test_knapsack(knapsack_func):
    """Print the solutions `knapsack_func` finds for the toy problem and a
    seed-5 random problem (the random inputs are echoed first)."""
    rule = "=" * 40
    print(rule)
    print(f"Test {knapsack_func.__name__}()")
    print(rule)
    print()

    # toy example: inputs are not echoed
    toy_names, toy_values, toy_weights, toy_capacity = make_toy_knapsack()
    toy_items = group_item_info(toy_names, toy_values, toy_weights)
    print(knapsack_func(toy_items, toy_capacity))
    print()

    # random example: echo the generated inputs, then solve
    names, values, weights, capacity = make_random_knapsack(seed=5)
    print(f"{names   = }")
    print(f"{values  = }")
    print(f"{weights = }")
    random_items = group_item_info(names, values, weights)
    print(knapsack_func(random_items, capacity))
    print()


# print()
# test_knapsack(knapsack_enumerate)
# test_knapsack(knapsack_enumerate_comprehension)


def test_performance(knapsack_func, max_num_items=30):
    """Time `knapsack_func` on random problems with 1..max_num_items items.

    For every size, prints the generated inputs, the solution, and the
    wall-clock duration measured with time.time().
    """
    rule = "=" * 40
    print(rule)
    print(f"Solving random knapsack scenarios, sizes 1 to {max_num_items}")
    print(rule)
    print()

    for size in range(1, max_num_items + 1):
        # for heavier items, also pass weight_range=(10, 50), capacity=12800
        names, values, weights, capacity = make_random_knapsack(num_items=size)
        header = (f"{size    = }", f"{names   = }", f"{values  = }", f"{weights = }")
        for line in header:
            print(line)

        items = group_item_info(names, values, weights)
        start = time.time()
        solution = knapsack_func(items, capacity)
        stop = time.time()
        print(f"{solution = }")
        print(f"duration = {stop - start : .2g} sec")
        print()


# print()
# test_performance(knapsack_enumerate, max_num_items=10)
# test_performance(knapsack_enumerate, max_num_items=40)


############################################################
# prune infeasible branches in decision tree
############################################################


def knapsack_decision_tree(items, capacity):
    """Solve 0/1 knapsack by recursing over a take/skip decision per item.

    Branches where an item would exceed the remaining capacity are pruned.

    Args:
        items: sequence of (name, value, weight) tuples; must support slicing.
        capacity: remaining weight budget.

    Returns:
        (names, value) for the best selection. Names come out in reverse order
        of `items`, because each level appends its item after recursing.
    """
    # base case: empty items has empty solution
    if len(items) == 0:
        return [], 0

    # recursion: decide whether to take first item, solve for remaining items
    item, remaining = items[0], items[1:]
    name, value, weight = item

    # branch where we don't take the first item
    sol_without, val_without = knapsack_decision_tree(remaining, capacity)

    # branch where we do take the first item: it fits when it does not exceed
    # the remaining capacity (an exact fit is allowed)
    if weight <= capacity:
        sol_with, val_with = knapsack_decision_tree(remaining, capacity - weight)
        sol_with.append(name)  # safe: every recursive call returns a fresh list
        val_with += value
    else:
        sol_with, val_with = [], 0

    # keep the better branch; ties go to the branch without the item
    if val_with > val_without:
        return sol_with, val_with
    return sol_without, val_without


# print()
# test_knapsack(knapsack_decision_tree)
# test_performance(knapsack_decision_tree, max_num_items=40)


def knapsack_indexed(items, capacity):
    """Solve 0/1 knapsack via indexed_helper, starting from the first item."""
    first_index = 0
    return indexed_helper(items, first_index, capacity)


def indexed_helper(items, index, capacity):
    """Best (names, value) using only items[index:] within `capacity`.

    Take/skip recursion that advances an index instead of slicing `items`,
    so no sub-lists are copied on the way down.

    Args:
        items: sequence of (name, value, weight) tuples.
        index: position of the next item to decide on.
        capacity: remaining weight budget.

    Returns:
        (names, value) for the best selection. Names come out in reverse order
        of `items`, because each level appends its item after recursing.
    """
    # base case: no items left to decide on
    if index == len(items):
        return [], 0

    name, value, weight = items[index]

    # branch where we skip items[index]
    sol_without, val_without = indexed_helper(items, index + 1, capacity)

    # branch where we take items[index]: it fits when it does not exceed the
    # remaining capacity (an exact fit is allowed)
    if weight <= capacity:
        sol_with, val_with = indexed_helper(items, index + 1, capacity - weight)
        sol_with.append(name)  # safe: every recursive call returns a fresh list
        val_with += value
    else:
        sol_with, val_with = [], 0

    # keep the better branch; ties go to the branch without the item
    if val_with > val_without:
        return sol_with, val_with
    return sol_without, val_without


# print()
# test_knapsack(knapsack_indexed)
# test_performance(knapsack_indexed, max_num_items=40)


############################################################
# memoize and look up overlapping subproblems
############################################################


def knapsack_indexed_memoized(items, capacity):
    """Solve 0/1 knapsack with memoized recursion via memoized_helper.

    A fresh memo dict is created per call, so no answers are shared between
    separate problems.
    """
    return memoized_helper(items, 0, capacity, memo={})


def memoized_helper(items, index, capacity, memo):
    """Best (names, value) using only items[index:] within `capacity`, memoized.

    Args:
        items: sequence of (name, value, weight) tuples.
        index: position of the next item to decide on.
        capacity: remaining weight budget.
        memo: dict mapping (index, capacity) to a computed answer; updated in place.

    Returns:
        (names, value) tuple. The names list may also be stored in `memo`, so
        mutating it would corrupt later lookups.
    """
    key = (index, capacity)

    # bypass if already in memo
    if key in memo:
        return memo[key]

    # base case: no more items, empty solution
    if index == len(items):
        return [], 0

    # recursion: decide whether to take items[index], solve for remaining items
    name, value, weight = items[index]

    sol_without, val_without = memoized_helper(items, index + 1, capacity, memo)

    # the item fits when it does not exceed the remaining capacity
    # (an exact fit is allowed)
    if weight <= capacity:
        sol_with, val_with = memoized_helper(items, index + 1, capacity - weight, memo)
        sol_with = sol_with + [name]  # new list: the recursive answer may be cached in memo
        val_with += value
    else:
        sol_with, val_with = [], 0

    # memoize answer before returning; ties go to the branch without the item
    if val_with > val_without:
        answer = sol_with, val_with
    else:
        answer = sol_without, val_without
    memo[key] = answer
    return answer


# print()
# test_knapsack(knapsack_indexed_memoized)
# test_performance(knapsack_indexed_memoized, max_num_items=40)


############################################################
# tabular approach
############################################################


def knapsack_indexed_tabular(items, capacity):
    """Solve 0/1 knapsack bottom-up by filling a table of subproblem answers.

    table[i][cap] holds the best (names, value) using only items[i:] with
    remaining capacity `cap`. Row len(items) is the base case (no items left),
    so rows are filled from the bottom up and the answer is table[0][capacity].

    Args:
        items: sequence of (name, value, weight) tuples. Assumes non-negative
            integer weights, because `cap - weight` is used as a column index.
        capacity: non-negative integer weight budget (sets the column count).

    Returns:
        (names, value) for the best selection that does not exceed `capacity`.
    """
    n = len(items)
    # n + 1 rows (one per item index plus the base-case row) and
    # capacity + 1 columns (one per remaining capacity 0..capacity)
    table = [[([], 0) for _ in range(capacity + 1)] for _ in range(n + 1)]

    # base case is encoded in the last row, where index == len(items);
    # fill the rest of the table upward from there
    for i in reversed(range(n)):
        name, value, weight = items[i]
        below = table[i + 1]

        for cap in range(capacity + 1):
            sol_without, val_without = below[cap]

            # the item fits when it does not exceed the remaining capacity
            # (an exact fit is allowed)
            if weight <= cap:
                sol_with, val_with = below[cap - weight]
                sol_with = sol_with + [name]  # new list: keep the row below unchanged
                val_with += value
            else:
                sol_with, val_with = [], 0

            if val_with > val_without:
                table[i][cap] = sol_with, val_with
            else:
                table[i][cap] = sol_without, val_without

    # retrieve answer for original problem
    return table[0][capacity]


# print()
# test_knapsack(knapsack_indexed_tabular)
# test_performance(knapsack_indexed_tabular, max_num_items=40)


############################################################
# effectiveness of dynamic programming depends on amount of subproblem overlap
############################################################


def test_overlap(knapsack_func, num_items=40):
    """Time `knapsack_func` while weights and capacity are scaled up together.

    The base weight range (1, 5) and base capacities 10..1280 are multiplied
    by 1, 10 and 100; one timing line is printed per (multiplier, capacity).
    """
    rule = "=" * 40
    print(rule)
    print(f"Evaluate effect of subproblem overlap on DP for knapsack")
    print(rule)
    print()

    base_weight_range = 1, 5
    base_capacity_range = 10, 20, 40, 80, 160, 320, 640, 1280
    for multiplier in (1, 10, 100):
        scaled_weight_range = [bound * multiplier for bound in base_weight_range]

        for base_cap in base_capacity_range:
            names, values, weights, capacity = make_random_knapsack(
                num_items=num_items,
                weight_range=scaled_weight_range,
                capacity=base_cap * multiplier,
            )
            items = group_item_info(names, values, weights)

            start = time.time()
            knapsack_func(items, capacity)  # only the duration is reported
            stop = time.time()
            print(f"{num_items = },   {capacity = :>6},   duration = {stop - start:.2g} sec")

        print()


# test_overlap(knapsack_indexed_memoized)


def test_continuous_weights(knapsack_func):
    """Time `knapsack_func` on random problems whose weights are floats.

    Problem sizes run from 1 to 39 items, using make_random_knapsack's default
    capacity and weight range with float_weights=True.
    """
    rule = "=" * 40
    print(rule)
    print(f"Evaluate effect of continuous weights on knapsack subproblem overlap")
    print(rule)
    print()

    for num_items in range(1, 40):
        problem = make_random_knapsack(num_items=num_items, float_weights=True)
        names, values, weights, capacity = problem
        items = group_item_info(names, values, weights)

        start = time.time()
        knapsack_func(items, capacity)  # only the duration is reported
        stop = time.time()
        print(f"{num_items = :>2},   {capacity = :>2},   duration = {stop - start:.2g} sec")

    print()


# test_continuous_weights(knapsack_indexed_memoized)
