import math
import random


############################################################
# lambda functions
############################################################


def bisection_square_root(x, epsilon=1e-3):
    """Approximate the square root of ``x`` by bisection.

    Args:
        x: non-negative number whose square root is wanted.
        epsilon: the search stops once the bracketing interval is no wider
            than this.

    Returns:
        The midpoint of the final interval, so within ``epsilon / 2`` of the
        true root.

    Raises:
        ValueError: if ``x`` is negative.
    """
    if x < 0:
        raise ValueError("cannot take the square root of a negative number")
    # For 0 < x < 1 the root is larger than x (sqrt(0.25) == 0.5), so the
    # upper bound must be at least 1 for the root to lie in [lower, upper].
    lower, upper = 0, max(1, x)
    mid = (lower + upper) / 2
    while upper - lower > epsilon:
        if mid ** 2 > x:
            upper = mid
        else:
            lower = mid
        mid = (lower + upper) / 2
    return mid


def test_bisection_square_root(precision=1e-3):
    """Print a comparison table of bisection_square_root() against math.sqrt.

    Each row holds the query (width 5, 2 decimals), the bisection estimate,
    and the absolute error of that estimate, the last two with 20 decimals.
    """
    for query in (2, 3, 4, 9, 25, 30.25, 50):
        estimate = bisection_square_root(query, precision)
        deviation = abs(estimate - math.sqrt(query))
        print("{:5.2f}  {:.20f}  {:.20f}".format(query, estimate, deviation))


# Module-level demo: print the bisection-vs-math.sqrt table twice, first at
# test_bisection_square_root()'s default precision, then at 1e-6.
# NOTE(review): this runs on every import of the module; consider guarding it
# with `if __name__ == "__main__":` if the file is ever imported elsewhere.
print()
test_bisection_square_root()
print()
test_bisection_square_root(precision=1e-6)


def bisection_search(x, inverse, epsilon=1e-3, lower=0, upper=None):
    """Find ``y`` with ``inverse(y)`` approximately equal to ``x``, by bisection.

    ``inverse`` is the function being inverted (e.g. ``cube`` to compute a
    cube root). It is assumed to be increasing on ``[lower, upper]``, and the
    answer is assumed to lie inside that interval.

    Args:
        x: target value.
        inverse: increasing function whose inverse is evaluated at ``x``.
        epsilon: the search stops once the interval is no wider than this.
        lower: lower end of the search interval (default 0).
        upper: upper end of the search interval. Defaults to ``max(1, x)``.
            This covers roots such as the cube root of 0.125, which is larger
            than ``x`` itself.

    Returns:
        The midpoint of the final interval.
    """
    if upper is None:
        # A plain ``x`` would exclude roots above x when 0 < x < 1.
        upper = max(1, x)
    mid = (lower + upper) / 2
    while upper - lower > epsilon:
        if inverse(mid) > x:
            upper = mid
        else:
            lower = mid
        mid = (lower + upper) / 2
    return mid


def cube(x):
    """Return ``x`` raised to the third power."""
    return pow(x, 3)


def exp2(x):
    """Return 2 raised to the power ``x``."""
    return pow(2, x)


# print()
# print(bisection_search(8, inverse=cube))
# print(bisection_search(27, inverse=cube))
# print(bisection_search(64, inverse=cube))
# print()
# print(bisection_search(16, inverse=exp2))
# print(bisection_search(32, inverse=exp2))
# print(bisection_search(64, inverse=exp2))

# print()
# print(bisection_search(8, lambda x: x**3))
# print(bisection_search(27, lambda x: x**3))
# print(bisection_search(64, lambda x: x**3))
# print()
# print(bisection_search(16, lambda x: 2**x))
# print(bisection_search(32, lambda x: 2**x))
# print(bisection_search(64, lambda x: 2**x))


def demo_lambda_keys():
    """Demonstrate min() on random triples, with and without a key function.

    Seeds ``random`` with 8 so the five triples of digits 1-9 are
    reproducible. Prints the triples, a blank line, their lexicographic
    minimum, and the triple with the smallest middle element.
    """
    random.seed(8)
    # Tuple elements are evaluated left to right, so the three draws per
    # triple happen in the same order as three separate randint() calls.
    triples = [
        (random.randint(1, 9), random.randint(1, 9), random.randint(1, 9))
        for _ in range(5)
    ]
    print(triples)

    print()
    print(min(triples))
    print(min(triples, key=lambda triple: triple[1]))
    # print(min(triples, key=lambda x: -x[0] * x[-1]))
    # print(min(triples, key=lambda x: sum(x)))
    # print(min(triples, key=sum))

    # print()
    # print(sorted(triples))
    # print(sorted(triples, key=lambda x: x[1]))
    # print(sorted(triples, key=lambda x: -x[0] * x[-1]))
    # print(sorted(triples, key=lambda x: sum(x)))
    # print(sorted(triples, key=sum))


# print()
# demo_lambda_keys()


# without key function: version given on exam 1
def break_cipher(ciphertext, alpha):
    """Brute-force the cipher and return the plaintext with the most words.

    Tries every ``magic`` in ``1 .. len(alpha) - 1`` with every ``shift`` in
    ``0 .. len(alpha) - 1``. Each candidate is paired with its word count and
    the pairs are compared directly, without a key function. Ties on the
    count therefore go to the greatest plaintext by ordinary comparison.

    ``decrypt`` and ``count_words`` are not defined in this file; presumably
    they come from Pset 1 — confirm before running.
    """
    scored = []
    for magic in range(1, len(alpha)):
        for shift in range(len(alpha)):
            candidate = decrypt(ciphertext, alpha, shift, magic)
            scored.append((count_words(candidate), candidate))
    best_pair = max(scored)
    return best_pair[1]


# with key function: call max() directly on the plaintexts, scoring each with
# count_words, instead of building [score, plaintext] proxy pairs
def break_cipher(ciphertext, alpha):
    """Brute-force the cipher, ranking candidate plaintexts by count_words.

    Candidates are generated for every ``magic`` in ``1 .. len(alpha) - 1``
    and every ``shift`` in ``0 .. len(alpha) - 1``. ``max`` with
    ``key=count_words`` returns the first candidate with the highest count.

    ``decrypt`` and ``count_words`` are not defined in this file; presumably
    they come from Pset 1 — confirm before running.
    """
    size = len(alpha)
    candidates = []
    for magic in range(1, size):
        candidates.extend(
            decrypt(ciphertext, alpha, shift, magic) for shift in range(size)
        )
    return max(candidates, key=count_words)


# EXERCISE: replace the break_cipher() function in your Pset 1 with the
# above, and make sure it still works


def demo_conditional_expr():
    """Print two reproducible lists of ten 3-digit numbers (seed 0).

    The draws alternate between ``a`` and ``z``. The commented-out code below
    compares the lists position by position, first with an if/else
    statement and then with a conditional expression.
    """
    random.seed(0)
    # Each pair is drawn a-first then z-first, preserving the draw order.
    draws = [
        (random.randint(100, 999), random.randint(100, 999)) for _ in range(10)
    ]
    a = [left for left, _ in draws]
    z = [right for _, right in draws]
    print(a)
    print(z)

    # picks = []
    # for i in range(10):
    #     if a[i] > z[i]:
    #         picks.append("A")
    #     else:
    #         picks.append("Z")
    # print(picks)

    # picks = []
    # for i in range(10):
    #     picks.append("A" if a[i] > z[i] else "Z")
    # print(picks)


# print()
# demo_conditional_expr()


def compare_adjacents(seq, compare):
    """Apply ``compare(previous, current)`` to each adjacent pair of ``seq``.

    Returns a list of ``len(seq) - 1`` results, or an empty list when
    ``seq`` has fewer than two elements.
    """
    return [compare(seq[idx - 1], seq[idx]) for idx in range(1, len(seq))]


def demo_conditional_lambda():
    """Print a reproducible list of ten digits 1-9 (seed 0).

    The commented-out code below computes the rise between neighbours, first
    with an explicit loop and then with compare_adjacents() and a
    conditional lambda.
    """
    random.seed(0)
    seq = [random.randint(1, 9) for _ in range(10)]
    print(seq)

    # rises = []
    # for i in range(1, len(seq)):
    #     if seq[i] > seq[i - 1]:
    #         rises.append(seq[i] - seq[i - 1])
    #     else:
    #         rises.append(0)
    # print(rises)

    # print(compare_adjacents(seq, lambda x, y: y - x if x < y else 0))
    # print(compare_adjacents(seq, lambda x, y: x - y if x > y else 0))


# print()
# demo_conditional_lambda()


############################################################
# comprehensions
############################################################


def list_construction():
    """Build two lists from range(10, 20) with explicit loops and print them.

    The first list doubles each number. The second evaluates the quadratic
    a*n**2 + b*n + c with a, b, c = 1, -2, 1.
    """
    source = range(10, 20)
    a, b, c = 1, -2, 1

    doubled = []
    for value in source:
        doubled.append(value * 2)

    quadratic = []
    for value in source:
        quadratic.append(a * value**2 + b * value + c)

    print(doubled)
    print(quadratic)


# print()
# list_construction()


def make_list(collection, transform):
    """Return a new list holding ``transform(item)`` for each item, in order."""
    return [transform(element) for element in collection]


def list_construction_abstracted():
    """Print doubled values and a quadratic over range(10, 20) via make_list().

    Each transformation is passed to make_list() as a lambda. The quadratic
    lambda closes over the local coefficients a, b, c = 1, -2, 1.
    """
    a, b, c = 1, -2, 1
    source = range(10, 20)
    doubled = make_list(source, lambda n: n * 2)
    print(doubled)
    print(make_list(source, lambda n: a * n**2 + b * n + c))


# print()
# list_construction_abstracted()


def list_construction_comprehension():
    """Print doubled values and a quadratic over range(10, 20) via comprehensions."""
    a, b, c = 1, -2, 1
    source = range(10, 20)
    doubled = [n * 2 for n in source]
    quadratic = [a * n**2 + b * n + c for n in source]
    print(doubled)
    print(quadratic)


# print()
# list_construction_comprehension()


def is_prime(num):
    """Return True if the integer ``num`` is prime, otherwise False.

    Integers below 2 (including 0 and negatives) are not prime. Trial
    division only needs to test divisors ``d`` with ``d * d <= num``: any
    factor larger than sqrt(num) pairs with one smaller than it.
    """
    if num < 2:
        return False
    divisor = 2
    # Integer comparison avoids float rounding errors of a sqrt-based bound.
    while divisor * divisor <= num:
        if num % divisor == 0:
            return False
        divisor += 1
    return True


def make_list_conditional(collection, transform, test=lambda x: True):
    """Return ``transform(item)`` for each item that passes ``test``, in order.

    By default every item passes, so this then behaves like a plain map into
    a list.
    """
    return [transform(element) for element in collection if test(element)]


def list_construction_conditional():
    """Print doubled evens, then doubled primes, from 1..99.

    Both lists are built with make_list_conditional(). The commented-out
    code below shows the equivalent comprehensions and further variations.
    """
    original = range(1, 100)

    print()
    for keep in (lambda x: x % 2 == 0, is_prime):
        print(make_list_conditional(
            original,
            transform=lambda x: x * 2,
            test=keep,
        ))

    # print()
    # print([x * 2 for x in original if x % 2 == 0])
    # print([x * 2 for x in original if is_prime(x)])

    # print()
    # print(make_list_conditional(
    #     original,
    #     transform=lambda x: x,
    #     test=lambda x: x % 10 == 7,
    # ))
    # print(make_list_conditional(
    #     original,
    #     transform=is_prime,
    #     test=lambda x: x % 10 == 7,
    # ))

    # print()
    # print([x for x in original if x % 10 == 7])
    # print([is_prime(x) for x in original if x % 10 == 7])
    # print(sum([is_prime(x) for x in original if x % 10 == 7]))
    # print(len([x for x in original if x % 10 == 7 and is_prime(x)]))


# print()
# list_construction_conditional()


def sum_digits(num):
    """Return the sum of the decimal digits of the integer ``num``.

    The sign is ignored, so ``sum_digits(-123) == 6``.
    """
    total = 0
    # abs() drops a leading '-', which int() cannot parse as a digit.
    for digit in str(abs(num)):
        total += int(digit)
    return total

    # EXERCISE: rewrite using a list comprehension


def other_comprehensions():
    """Show set and dict comprehensions over the integers 101..149.

    Prints the set of digit sums, a number -> digit-sum dict, and a
    digit-sum -> number dict. It then prints the numbers as strings as the
    starting point for the exercise below.
    """
    numbers = range(101, 150)

    print()
    print({sum_digits(n) for n in numbers})
    print()
    print({n: sum_digits(n) for n in numbers})
    print()
    # Later numbers overwrite earlier ones that share a digit sum, so each
    # key ends up mapped to the largest number with that digit sum.
    print({sum_digits(n): n for n in numbers})

    print()

    # EXERCISE: for each integer in a range of positive integers, sort its
    # digits in ascending order; then list all the unique resulting digit
    # sequences, themselves in ascending order.
    print()
    print([str(n) for n in numbers])
    ...


# print()
# other_comprehensions()


# EXERCISE: rewrite break_cipher() using a list comprehension
def break_cipher(ciphertext, alpha):
    """Brute-force the cipher and return the candidate with the most words.

    A list comprehension builds every candidate plaintext, covering each
    ``magic`` in ``1 .. len(alpha) - 1`` and each ``shift`` in
    ``0 .. len(alpha) - 1``. ``max`` with ``key=count_words`` then picks the
    first candidate with the highest word count.

    This was previously a ``max(...)`` stub that raised TypeError on any call
    and, as the last definition in the module, shadowed the working versions.

    ``decrypt`` and ``count_words`` are not defined in this file; presumably
    they come from Pset 1 — confirm before running.
    """
    candidates = [
        decrypt(ciphertext, alpha, shift, magic)
        for magic in range(1, len(alpha))
        for shift in range(len(alpha))
    ]
    return max(candidates, key=count_words)
