import random
import string
import time


############################################################
# knapsack problem specs
############################################################


def make_toy_knapsack(discrete=False):
    """Build the fixed three-item toy knapsack instance.

    The diamond's density is 500 when ``discrete`` is true and 4 otherwise;
    all other numbers are constant.

    Returns:
        A ``(names, values, weights, densities, capacity)`` tuple in which
        each value is the matching density multiplied by the matching weight.
    """
    diamond_density = 500 if discrete else 4

    names = ["gold", "saffron", "diamond"]
    weights = [5, 4, 3]
    densities = [50, 60, diamond_density]

    # value of an item = value per unit weight * weight
    values = []
    for per_unit, weight in zip(densities, weights):
        values.append(per_unit * weight)

    return names, values, weights, densities, 10


def make_names_list(num_items):
    """Return ``num_items`` distinct lowercase item names.

    Names follow spreadsheet-column order: ``a`` .. ``z``, then ``aa``,
    ``ab``, .. ``zz``, then ``aaa`` and so on.  Names are generated on demand,
    so any non-negative count is honoured (there is no cap at the 702 one-
    and two-letter names) and no unused names are built.

    Args:
        num_items: how many names to produce; zero or a negative count
            yields an empty list.

    Returns:
        A list of exactly ``max(num_items, 0)`` unique strings.
    """
    letters = string.ascii_lowercase
    base = len(letters)

    names = []
    for index in range(num_items):
        # bijective base-26 conversion: 0 -> "a", 25 -> "z", 26 -> "aa", ...
        chars = []
        position = index + 1
        while position > 0:
            position, digit = divmod(position - 1, base)
            chars.append(letters[digit])
        # digits come out least-significant first
        names.append("".join(reversed(chars)))
    return names


def make_random_knapsack(
    seed=5,
    num_items=10, weight_range=(1, 5), capacity=20,
    float_weights=False,
):
    """Generate a reproducible random knapsack instance.

    Reseeds the module-level ``random`` generator with ``seed``, then draws
    an integer density in [10, 100] and a weight from ``weight_range`` for
    every item.  Weights are sampled with ``random.uniform`` when
    ``float_weights`` is true and with ``random.randint`` otherwise.

    Returns:
        A ``(names, values, weights, densities, capacity)`` tuple; each value
        is density * weight and ``capacity`` is passed through unchanged.
    """
    # reseeding makes the instance depend only on the arguments
    random.seed(seed)
    names = make_names_list(num_items)

    # all densities are drawn before any weight; this order determines
    # which numbers a given seed produces
    densities = []
    for _ in range(num_items):
        densities.append(random.randint(10, 100))

    if float_weights:
        draw_weight = random.uniform
    else:
        draw_weight = random.randint

    weights = []
    for _ in range(num_items):
        weights.append(draw_weight(*weight_range))

    values = [per_unit * w for per_unit, w in zip(densities, weights)]
    return names, values, weights, densities, capacity


def group_item_info(names, values, weights, densities):
    """Zip the four parallel attribute lists into per-item tuples.

    Args:
        names, values, weights, densities: sequences of equal length, where
            position ``i`` in each describes item ``i``.

    Returns:
        A list of ``(name, value, weight, density)`` tuples.

    Raises:
        AssertionError: if the four sequences do not all have the same
            length (``values`` is checked too, so a short or long ``values``
            list can no longer slip through).
    """
    assert len(names) == len(values) == len(weights) == len(densities), (
        "names, values, weights and densities must all have the same length"
    )
    return list(zip(names, values, weights, densities))


def item_names(items):
    """Return the name (field 0) of every item, in order."""
    names = []
    for entry in items:
        names.append(entry[0])
    return names


def total_value(items):
    """Return the sum of the value field (field 1) over all items."""
    running = 0
    for entry in items:
        running += entry[1]
    return running


def total_weight(items):
    """Return the sum of the weight field (field 2) over all items."""
    return sum(map(lambda entry: entry[2], items))


############################################################
# greedy approaches to continuous knapsack and 0/1 knapsack
############################################################


def knapsack_continuous(items, capacity):
    """Greedily solve the continuous (fractional) knapsack problem.

    Items are visited from highest to lowest density; each contributes as
    much of its weight as still fits, until no capacity is left.

    Args:
        items: ``(name, value, weight, density)`` tuples.
        capacity: total weight the knapsack can hold.

    Returns:
        ``(solution, total)`` where ``solution`` maps each visited item name
        to the amount taken and ``total`` is the sum of amount * density.
    """
    # densest first; list.sort is stable, so ties keep their input order
    by_density = list(items)
    by_density.sort(key=lambda entry: entry[3], reverse=True)

    taken = {}
    collected = 0
    space_left = capacity
    for name, _value, weight, density in by_density:
        if space_left > 0:
            # take the whole item, or only the part that still fits
            portion = min(weight, space_left)
            taken[name] = portion
            collected += portion * density
            space_left -= portion
        else:
            break
    return taken, collected


def test_knapsack_continuous():
    """Print continuous-knapsack solutions for the toy and a random instance."""
    print("=" * 40)
    # header previously misspelled "Continous"
    print("Continuous knapsack scenarios")
    print("=" * 40)
    print()

    # toy example
    names, values, weights, densities, capacity = make_toy_knapsack()
    items = group_item_info(names, values, weights, densities)
    print(knapsack_continuous(items, capacity))
    print()

    # random example: show the instance before its solution
    names, values, weights, densities, capacity = make_random_knapsack()
    print(f"{names =     }")
    print(f"{densities = }")
    print(f"{weights =   }")
    items = group_item_info(names, values, weights, densities)
    print(knapsack_continuous(items, capacity))
    print()


# test_knapsack_continuous()


def knapsack_discrete_greedy(items, capacity, metric="value"):
    """Greedily pick whole items for the 0/1 knapsack problem.

    Items are visited from largest to smallest ``metric`` and every item that
    still fits is taken.  The result is a heuristic, not necessarily optimal.

    Args:
        items: ``(name, value, weight, density)`` tuples.
        capacity: total weight the knapsack can hold.
        metric: ``"value"``, ``"weight"`` or ``"density"`` -- the field used
            to order the items.

    Returns:
        ``(solution, total)``: the names taken, in the order they were
        picked, and the sum of their values.

    Raises:
        KeyError: if ``metric`` is not one of the three supported names.
    """
    # sort items by metric
    metric_funcs = {
        "value": lambda x: x[1],
        "weight": lambda x: x[2],
        "density": lambda x: x[3],
    }
    ordered_items = sorted(items, key=metric_funcs[metric], reverse=True)

    # collect items greedily
    solution = []
    collected_value = 0
    remaining = capacity
    for name, value, weight, _density in ordered_items:
        if not remaining > 0:
            break
        # "<=" so that an item filling the knapsack exactly is still taken
        if weight <= remaining:
            solution.append(name)
            collected_value += value
            remaining -= weight
    return solution, collected_value


def test_knapsack_discrete_greedy():
    """Print greedy 0/1 solutions under each metric for two instances."""
    rule = "=" * 40
    print(rule)
    print("Discrete knapsack scenarios, solved greedily")
    print(rule)
    print()

    metrics = ("value", "weight", "density")

    # toy example
    names, values, weights, densities, capacity = make_toy_knapsack(discrete=True)
    items = group_item_info(names, values, weights, densities)
    for metric in metrics:
        print(knapsack_discrete_greedy(items, capacity, metric=metric))
    print()

    # random example: show the instance, then one solution per metric
    names, values, weights, densities, capacity = make_random_knapsack(seed=5)
    print(f"{names   = }")
    print(f"{values  = }")
    print(f"{weights = }")
    items = group_item_info(names, values, weights, densities)
    for metric in metrics:
        print(knapsack_discrete_greedy(items, capacity, metric=metric))
    print()


# test_knapsack_discrete_greedy()


############################################################
# exhaustive enumeration for optimal 0/1 knapsack solution
############################################################


def get_all_combos(items):
    """Return every subset of ``items`` as a list of lists.

    Subsets containing an element come before those that omit it, applied
    element by element from the front, so the full selection is first and
    the empty selection is last.  An empty input yields ``[[]]``.
    """
    # start from the only combination of nothing, then fold the elements in
    # from last to first: each step doubles the list by putting the element
    # in front of every combination built so far
    combos = [[]]
    for element in reversed(items):
        extended = [[element] + combo for combo in combos]
        combos = extended + combos
    return combos


# generating combos via binary digits
def get_all_combos_binary(items):
    """Return every subset of ``items``, enumerated by counting in binary.

    For each ``k`` from 0 to 2**n - 1, the n-bit binary form of ``k`` is read
    left to right: position ``i`` being 1 means ``items[i]`` is selected.
    The empty selection is therefore first and the full selection last.
    """
    count = len(items)
    result = []
    for k in range(2 ** count):
        # bit (count - 1 - i) of k is digit i of k written as a zero-padded
        # count-digit binary string (what f"{k:0{count}b}" would produce)
        chosen = [
            items[i]
            for i in range(count)
            if (k >> (count - 1 - i)) & 1
        ]
        result.append(chosen)
    return result

# get_all_combos = get_all_combos_binary


# print(get_all_combos(range(3)))


def knapsack_enumerate(items, capacity):
    """Solve 0/1 knapsack optimally by checking every combination.

    Args:
        items: ``(name, value, weight, density)`` tuples.
        capacity: total weight the knapsack can hold.

    Returns:
        ``(names, value)`` of the first best feasible combination found.
        When no combination has positive value the result is ``([], 0)``.
    """
    # the empty selection is the baseline, so the result is always a list
    best_candidate = []
    best_value = 0
    for selection in get_all_combos(items):
        # feasible means total weight does not exceed capacity; a selection
        # that fills the knapsack exactly is allowed
        if total_weight(selection) > capacity:
            continue
        value = total_value(selection)
        if value > best_value:
            best_candidate = item_names(selection)
            best_value = value
    return best_candidate, best_value


def knapsack_enumerate_comprehension(items, capacity):
    """Solve 0/1 knapsack optimally with a single comprehension and ``max``.

    Args:
        items: ``(name, value, weight, density)`` tuples.
        capacity: total weight the knapsack can hold.

    Returns:
        ``(names, value)`` of the first feasible combination with the
        greatest value, or ``([], 0)`` if no combination is feasible.
    """
    # feasible means total weight does not exceed capacity ("<=", so an
    # exact fit counts)
    feasible = [
        (item_names(selection), total_value(selection))
        for selection in get_all_combos(items)
        if total_weight(selection) <= capacity
    ]
    # default keeps max() from raising when nothing is feasible
    return max(feasible, key=lambda x: x[1], default=([], 0))


def test_knapsack_discrete_optimal(knapsack_func):
    """Print what ``knapsack_func`` returns for the toy and a random instance.

    ``knapsack_func`` is called as ``knapsack_func(items, capacity)`` with
    ``items`` being the list built by ``group_item_info``.
    """
    rule = "=" * 40
    print(rule)
    print("Discrete knapsack scenarios, solved optimally")
    print(rule)
    print()

    # toy example
    names, values, weights, densities, capacity = make_toy_knapsack(discrete=True)
    toy_items = group_item_info(names, values, weights, densities)
    toy_answer = knapsack_func(toy_items, capacity)
    print(toy_answer)
    print()

    # random example: show the instance before its solution
    names, values, weights, densities, capacity = make_random_knapsack(seed=5)
    print(f"{names   = }")
    print(f"{values  = }")
    print(f"{weights = }")
    random_items = group_item_info(names, values, weights, densities)
    random_answer = knapsack_func(random_items, capacity)
    print(random_answer)
    print()


# test_knapsack_discrete_optimal(knapsack_enumerate)
# test_knapsack_discrete_optimal(knapsack_enumerate_comprehension)


def test_greedy_vs_optimal():
    """Print greedy and enumerated solutions side by side for seeds 5 to 10."""
    rule = "=" * 40
    print(rule)
    print("Random discrete knapsack scenarios, greedy vs optimal solutions")
    print(rule)
    print()

    # label text is padded so the printed solutions line up
    greedy_rows = (
        ("greedy value:  ", "value"),
        ("greedy weight: ", "weight"),
        ("greedy density:", "density"),
    )

    for seed in range(5, 11):
        names, values, weights, densities, capacity = make_random_knapsack(seed=seed)
        print(f"{seed = }")
        print(f"{names   = }")
        print(f"{values  = }")
        print(f"{weights = }")

        items = group_item_info(names, values, weights, densities)
        for label, metric in greedy_rows:
            print(label, knapsack_discrete_greedy(items, capacity, metric=metric))
        print("optimal:       ", knapsack_enumerate(items, capacity))
        print()


# test_greedy_vs_optimal()


def test_performance(knapsack_func, upper_limit=30):
    """Time ``knapsack_func`` on random instances of growing size.

    For every size from 1 to ``upper_limit`` a random instance is generated
    (default seed and capacity), printed, solved, and the solution and
    elapsed time are printed.

    Args:
        knapsack_func: solver called as ``knapsack_func(items, capacity)``.
        upper_limit: largest number of items to try.
    """
    print("=" * 40)
    print(f"Optimally solving random discrete knapsack scenarios, sizes 1 to {upper_limit}")
    print("=" * 40)
    print()

    for size in range(1, upper_limit + 1):
        names, values, weights, densities, capacity = make_random_knapsack(
            num_items=size,
        )
        print(f"{size    = }")
        print(f"{names   = }")
        print(f"{values  = }")
        print(f"{weights = }")

        items = group_item_info(names, values, weights, densities)
        # perf_counter is monotonic and high-resolution, unlike time.time()
        start = time.perf_counter()
        solution = knapsack_func(items, capacity)
        stop = time.perf_counter()
        print(f"{solution = }")
        print(f"duration = {stop - start : .2g} sec")
        print()


# test_performance(knapsack_enumerate)


############################################################
# prune infeasible branches in decision tree
############################################################


def knapsack_decision_tree(items, capacity):
    """Solve 0/1 knapsack optimally by recursing over take/skip decisions.

    Branches in which the current item does not fit are pruned.

    Args:
        items: ``(name, value, weight, density)`` tuples.
        capacity: weight still available in the knapsack.

    Returns:
        ``(names, value)`` of an optimal selection.  Names are appended after
        the recursive call returns, so they appear in reverse item order.
    """
    # base case: empty items has empty solution
    if len(items) == 0:
        return [], 0

    # recursion: decide whether to take first item, solve for remaining items
    item, remaining = items[0], items[1:]
    name, value, weight, _ = item

    # consider branch where we don't take items[0]
    sol_without, val_without = knapsack_decision_tree(remaining, capacity)

    # consider branch where we do take items[0]; "<=" so that an item
    # filling the remaining capacity exactly is still allowed
    if weight <= capacity:
        sol_with, val_with = knapsack_decision_tree(remaining, capacity - weight)
        sol_with.append(name)
        val_with += value
    else:
        sol_with = []
        val_with = 0

    # determine which branch is better (ties go to not taking the item)
    if val_with > val_without:
        return sol_with, val_with
    return sol_without, val_without


# test_knapsack_discrete_optimal(knapsack_decision_tree)
# test_performance(knapsack_decision_tree, upper_limit=40)


def knapsack_indexed(items, capacity):
    """Entry point of the index-based solver.

    Delegates to ``knapsack_indexed_helper`` starting at item index 0 and
    returns whatever the helper returns.
    """
    first_index = 0
    return knapsack_indexed_helper(items, first_index, capacity)


def knapsack_indexed_helper(items, index, capacity):
    """Solve 0/1 knapsack for ``items[index:]`` without slicing the list.

    Same take/skip recursion as a decision tree, but the position in
    ``items`` is tracked by ``index`` so no sub-list copies are made.

    Args:
        items: ``(name, value, weight, density)`` tuples.
        index: position of the first item still to be decided.
        capacity: weight still available in the knapsack.

    Returns:
        ``(names, value)`` of an optimal selection from ``items[index:]``.
        Names are appended after the recursive call returns, so they appear
        in reverse item order.
    """
    # base case: no more items, empty solution
    if index == len(items):
        return [], 0

    name, value, weight, _ = items[index]

    # branch where items[index] is skipped
    sol_without, val_without = knapsack_indexed_helper(items, index + 1, capacity)

    # branch where items[index] is taken, only explored if the item fits
    if weight <= capacity:
        sol_with, val_with = knapsack_indexed_helper(
            items, index + 1, capacity - weight
        )
        sol_with.append(name)
        val_with += value
    else:
        sol_with, val_with = [], 0

    # ties go to not taking the item
    if val_with > val_without:
        return sol_with, val_with
    return sol_without, val_without


# test_knapsack_discrete_optimal(knapsack_indexed)
# test_performance(knapsack_indexed, upper_limit=40)


############################################################
# memoize and look up overlapping subproblems
############################################################


def knapsack_indexed_memoized(items, capacity):
    """Entry point of the memoized index-based solver.

    Calls ``kim_helper`` from item index 0 with a fresh, empty memo dict, so
    nothing is cached across separate top-level calls.
    """
    return kim_helper(items, 0, capacity, {})


def kim_helper(items, index, capacity, memo):
    """Memoized take/skip recursion for 0/1 knapsack on ``items[index:]``.

    Args:
        items: ``(name, value, weight, density)`` tuples.
        index: position of the first item still to be decided.
        capacity: weight still available in the knapsack.
        memo: dict mapping ``(index, capacity)`` to an already computed
            answer; it is read and updated in place.

    Returns:
        ``(names, value)`` of an optimal selection from ``items[index:]``,
        with names in reverse item order.
    """
    # bypass if already in memo
    if (index, capacity) in memo:
        return memo[index, capacity]

    # base case: no more items, empty solution
    if index == len(items):
        return [], 0

    # recursion: decide whether to take first item, solve for remaining items
    name, value, weight, _ = items[index]

    sol_without, val_without = kim_helper(items, index + 1, capacity, memo)

    # "<=" so that an item filling the remaining capacity exactly is allowed
    if weight <= capacity:
        sol_with, val_with = kim_helper(items, index + 1, capacity - weight, memo)
        # build a new list: the one returned may be stored in memo and
        # must not be mutated
        sol_with = sol_with + [name]
        val_with += value
    else:
        sol_with = []
        val_with = 0

    # ties go to not taking the item
    if val_with > val_without:
        answer = sol_with, val_with
    else:
        answer = sol_without, val_without
    memo[index, capacity] = answer
    return answer


# test_knapsack_discrete_optimal(knapsack_indexed_memoized)
# test_performance(knapsack_indexed_memoized, upper_limit=40)


############################################################
# tabular approach
############################################################


def knapsack_indexed_tabular(items, capacity):
    """Solve 0/1 knapsack bottom-up with a table of full subproblem answers.

    ``table[i][cap]`` holds the ``(names, value)`` answer for ``items[i:]``
    with ``cap`` capacity available.

    Args:
        items: ``(name, value, weight, density)`` tuples; weights are used as
            column offsets, so they are assumed to be integers.
        capacity: total weight the knapsack can hold; assumed to be a
            non-negative integer because it sizes the table columns.

    Returns:
        ``(names, value)`` of an optimal selection, with names in reverse
        item order.
    """
    # rows are indexed by item index, (len(items) + 1) rows
    # columns are indexed by remaining capacity, (capacity + 1) columns
    n = len(items)
    table = [
        [([], 0) for _ in range(capacity + 1)]
        for _ in range(n + 1)
    ]

    # base case is encoded in last row, where index == len(items)

    # fill in table from bottom row, since row i only depends on row i + 1
    for i in range(n - 1, -1, -1):
        name, value, weight, _ = items[i]

        for cap in range(capacity + 1):

            sol_without, val_without = table[i + 1][cap]

            # "<=" so that an item filling the capacity exactly is allowed
            if weight <= cap:
                sol_with, val_with = table[i + 1][cap - weight]
                # new list: the table entry below must stay untouched
                sol_with = sol_with + [name]
                val_with += value
            else:
                sol_with = []
                val_with = 0

            # ties go to not taking the item
            if val_with > val_without:
                table[i][cap] = sol_with, val_with
            else:
                table[i][cap] = sol_without, val_without

    # retrieve answer for original problem
    return table[0][capacity]


# test_knapsack_discrete_optimal(knapsack_indexed_tabular)
# test_performance(knapsack_indexed_tabular, upper_limit=40)


def knapsack_indexed_tabular_values(items, capacity):
    """Solve 0/1 knapsack bottom-up, storing only values in the table.

    ``table[i][cap]`` is the best value achievable with ``items[i:]`` and
    ``cap`` capacity.  The chosen items are recovered afterwards by walking
    the table from the first row down.

    Args:
        items: ``(name, value, weight, density)`` tuples; weights are used as
            column offsets, so they are assumed to be integers.
        capacity: total weight the knapsack can hold; assumed to be a
            non-negative integer because it sizes the table columns.

    Returns:
        ``(names, value)`` of an optimal selection, names in item order.
    """
    # initialize table of subproblem VALUES
    n = len(items)
    table = [[0 for _ in range(capacity + 1)] for _ in range(n + 1)]

    # base case is encoded in last row, where index == len(items)

    # fill in table with VALUES only, from the bottom row up
    for i in range(n - 1, -1, -1):
        _name, value, weight, _ = items[i]

        for cap in range(capacity + 1):

            val_without = table[i + 1][cap]

            # "<=" so that an item filling the capacity exactly is allowed
            if weight <= cap:
                val_with = table[i + 1][cap - weight] + value
            else:
                val_with = 0

            # ties go to not taking the item
            table[i][cap] = val_with if val_with > val_without else val_without

    # trace through values to find items in the optimal solution
    solution = []
    best_value = table[0][capacity]

    remaining = capacity
    for i in range(n):
        # values stay the same down one column only if optimal NOT to take an item
        if table[i][remaining] == table[i + 1][remaining]:
            continue
        name, _value, weight, _ = items[i]
        solution.append(name)
        remaining -= weight

    return solution, best_value


# test_knapsack_discrete_optimal(knapsack_indexed_tabular_values)
# test_performance(knapsack_indexed_tabular_values, upper_limit=40)


############################################################
# effectiveness of dynamic programming depends on amount of subproblem overlap
############################################################


def test_overlap(knapsack_func, num_items=40):
    """Time ``knapsack_func`` while scaling weights and capacities together.

    The base weight range (1, 5) and the base capacities 10 .. 1280 are
    multiplied by 1, 10 and 100; one random instance is solved and timed for
    each capacity, and one line is printed per run.

    Args:
        knapsack_func: solver called as ``knapsack_func(items, capacity)``.
        num_items: number of items in every generated instance.
    """
    print("=" * 40)
    print("Evaluate effect of subproblem overlap on DP for knapsack")
    print("=" * 40)
    print()

    base_weight_range = (1, 5)
    base_capacity_range = (10, 20, 40, 80, 160, 320, 640, 1280)

    for multiplier in (1, 10, 100):
        weight_range = [x * multiplier for x in base_weight_range]
        capacity_range = [x * multiplier for x in base_capacity_range]

        for cap in capacity_range:
            names, values, weights, densities, capacity = make_random_knapsack(
                num_items=num_items,
                weight_range=weight_range,
                capacity=cap,
            )

            items = group_item_info(names, values, weights, densities)
            # perf_counter is monotonic and high-resolution, unlike time.time();
            # only the elapsed time is reported, so the result is discarded
            start = time.perf_counter()
            knapsack_func(items, capacity)
            stop = time.perf_counter()

            print(f"{num_items = }, {capacity = :>6}, duration = {stop - start:.2g} sec")
        print()


# test_overlap(knapsack_indexed_tabular)


def test_continuous_weights(knapsack_func):
    """Time ``knapsack_func`` on instances whose weights are random floats.

    Instances with 1 to 39 items are generated with ``float_weights=True``
    and one line with the elapsed time is printed per instance.

    Args:
        knapsack_func: solver called as ``knapsack_func(items, capacity)``.
    """
    print("=" * 40)
    print("Evaluate effect of subproblem overlap on DP for knapsack")
    print("=" * 40)
    print()

    for num_items in range(1, 40):
        names, values, weights, densities, capacity = make_random_knapsack(
            num_items=num_items, float_weights=True
        )
        items = group_item_info(names, values, weights, densities)
        # perf_counter is monotonic and high-resolution, unlike time.time();
        # only the elapsed time is reported, so the result is discarded
        start = time.perf_counter()
        knapsack_func(items, capacity)
        stop = time.perf_counter()
        print(f"{num_items = :>2}, {capacity = :>2}, duration = {stop - start:.2g} sec")
    print()


# test_continuous_weights(knapsack_indexed_memoized)
