"""Tests for itertools exercises"""
import unittest

from iteration import (
    lstrip,
    compact,
    total_length,
    random_number_generator,
    running_mean,
    stop_on,
)


class LStripTests(unittest.TestCase):

    """Tests for lstrip."""

    def assertIterableEqual(self, iterable1, iterable2):
        self.assertEqual(list(iterable1), list(iterable2))

    def test_list(self):
        self.assertIterableEqual(lstrip([1, 1, 2, 3, 1], 1), [2, 3, 1])

    def test_nothing_to_strip(self):
        self.assertIterableEqual(lstrip([1, 2, 3], 0), [1, 2, 3])

    def test_string(self):
        self.assertIterableEqual(lstrip("  hello", " "), "hello")

    def test_empty_iterable(self):
        self.assertIterableEqual(lstrip([], 1), [])

    def test_strip_all(self):
        self.assertIterableEqual(lstrip([1, 1, 1], 1), [])

    def test_none_values(self):
        self.assertIterableEqual(lstrip([None, 1, 2, 3], 0), [None, 1, 2, 3])

    def test_iterator(self):
        squares = (n**2 for n in [0, 0, 1, 2, 3])
        self.assertIterableEqual(lstrip(squares, 0), [1, 4, 9])

    def test_returns_iterator(self):
        stripped = lstrip((1, 2, 3), 1)
        self.assertEqual(iter(stripped), iter(stripped))


class TotalLengthTests(unittest.TestCase):

    """Tests for total_length."""

    def test_list(self):
        self.assertEqual(total_length([1, 2, 3]), 3)

    def test_nothing(self):
        self.assertEqual(total_length(), 0)

    def test_iterators(self):
        self.assertEqual(total_length([1, 2, 3], [4, 5], iter([6, 7])), 7)


class CompactTests(unittest.TestCase):

    """Tests for compact."""

    def assertIterableEqual(self, iterable1, iterable2):
        self.assertEqual(list(iterable1), list(iterable2))

    def test_no_duplicates(self):
        self.assertIterableEqual(compact([1, 2, 3]), [1, 2, 3])

    def test_adjacent_duplicates(self):
        self.assertIterableEqual(compact([1, 1, 2, 2, 3]), [1, 2, 3])

    def test_non_adjacent_duplicates(self):
        self.assertIterableEqual(compact([1, 2, 3, 1, 2]), [1, 2, 3, 1, 2])

    def test_lots_of_adjacent_duplicates(self):
        self.assertIterableEqual(compact([1, 1, 1, 1, 1, 1]), [1])

    def test_empty_values(self):
        self.assertIterableEqual(compact([None, 0, "", []]), [None, 0, "", []])

    def test_empty_list(self):
        self.assertIterableEqual(compact([]), [])

    def test_accepts_iterator(self):
        nums = (n**2 for n in [1, 2, 3])
        self.assertIterableEqual(compact(nums), [1, 4, 9])

    def test_returns_iterator(self):
        nums = iter([1, 2, 3])
        output = compact(nums)
        self.assertEqual(iter(output), iter(output))
        self.assertEqual(next(output), 1)
        self.assertEqual(next(nums), 2)


class StopOnTests(unittest.TestCase):

    """Tests for stop_on."""

    def test_last_item(self):
        self.assertEqual(list(stop_on([1, 2, 3], 3)), [1, 2])

    def test_first_item(self):
        self.assertEqual(list(stop_on([1, 2, 3], 1)), [])

    def test_not_in(self):
        self.assertEqual(list(stop_on([1, 2, 3], 4)), [1, 2, 3])

    def test_repeats(self):
        self.assertEqual(list(stop_on([1, 1, 2, 2, 1, 2], 2)), [1, 1])


class RandomNumberGeneratorTests(unittest.TestCase):

    """Tests for random_number_generator."""

    def test_iterator(self):
        number_generator = random_number_generator()
        self.assertEqual(next(number_generator), 4)
        self.assertIs(iter(number_generator), number_generator)

    def test_generate_forever(self):
        number_generator = random_number_generator()
        output = [next(number_generator) for _ in range(999)]
        many_fours = [4] * 999
        self.assertEqual(output, many_fours)


class RunningMeanTests(unittest.TestCase):

    """Tests for running_mean."""

    def test_multiple_numbers(self):
        inputs = [8, 4, 3, 1, 3, 5]
        outputs = [8.0, 6.0, 5.0, 4.0, 3.8, 4.0]
        self.assertEqual(list(running_mean(inputs)), outputs)


if __name__ == "__main__":
    # Runs only when this file is executed directly as a script, not when it
    # is imported (e.g. collected by a test runner). Delegates to the
    # project's helpers.error_message — presumably it explains how these
    # tests should be run instead; confirm in helpers.
    from helpers import error_message

    error_message()
