"""Test helpers"""
from contextlib import contextmanager, redirect_stdout, redirect_stderr
from pathlib import Path
from tempfile import NamedTemporaryFile
from importlib.machinery import SourceFileLoader
from io import StringIO
import sys
from unittest.mock import patch
import unittest


DIR_PATH = Path(__file__).parent


class DummyException(BaseException):
    """Placeholder exception type that nothing is expected to raise.

    Used as the default ``raises`` argument of ``run_program`` to mean "no
    exception is expected".  It derives directly from ``BaseException``, so
    an ``except DummyException`` clause does not catch ``Exception``
    subclasses or ``SystemExit``.
    """


def run_program(program, args=None, raises=DummyException, stderr=False):
    """Run ``modules/<program>`` in-process as ``__main__`` and capture output.

    The program is executed with ``sys.argv`` set to ``[program] + args``
    and with stdout/stderr redirected to in-memory buffers.  Both
    ``sys.argv`` and the ``__main__`` entry of ``sys.modules`` are restored
    to their previous values before this function returns or raises.

    Args:
        program: File name of the program inside the ``modules`` directory
            located next to this file.
        args: List of string command-line arguments (default: no arguments).
        raises: Exception type the program is expected to raise.  The
            default, ``DummyException``, means no exception is expected;
            ``SystemExit`` means a non-zero exit is expected.
        stderr: When true, stderr is captured in its own buffer instead of
            being merged into the stdout buffer.

    Returns:
        If an expected exception other than ``SystemExit`` was raised, the
        captured error text (which is the merged output unless ``stderr``
        is true).  Otherwise the captured stdout text, or a
        ``(stdout, stderr)`` tuple when ``stderr`` is true.

    Raises:
        AssertionError: If an argument is not a string, if the expected
            exception was not raised, or if ``raises`` is ``SystemExit`` but
            the program exited with a zero status.
        SystemExit: If the program exited with an unexpected non-zero
            status; the captured error output is used as the exit message.
    """
    # Local import: load_module() is deprecated, and the replacement API
    # lives in importlib.util, which the file does not import at the top.
    import importlib.util

    if args is None:
        args = []
    assert all(isinstance(a, str) for a in args)
    old_args = sys.argv
    # Remember the real __main__ so running a program does not leave
    # sys.modules pointing at (or missing) the wrong module afterwards.
    old_main = sys.modules.get('__main__')
    try:
        module_path = DIR_PATH / 'modules' / program
        sys.argv = [program] + args
        with redirect_stdout(StringIO()) as output:
            error = StringIO() if stderr else output
            with redirect_stderr(error):
                exited = False
                try:
                    loader = SourceFileLoader('__main__', str(module_path))
                    spec = importlib.util.spec_from_loader('__main__', loader)
                    module = importlib.util.module_from_spec(spec)
                    # Register before executing so code in the program that
                    # looks up sys.modules['__main__'] finds this module.
                    sys.modules['__main__'] = module
                    loader.exec_module(module)
                except SystemExit as e:
                    exited = True
                    # sys.exit() and sys.exit(None) mean success, same as 0.
                    nonzero = e.code not in (None, 0)
                    if raises == SystemExit:
                        assert nonzero, "Non-zero exit code expected"
                    elif nonzero:
                        raise SystemExit(error.getvalue()) from e
                except raises:
                    return error.getvalue()
                # An expected SystemExit that actually happened was already
                # validated above; every other expectation is unmet here.
                expected_exit_seen = exited and raises == SystemExit
                if raises is not DummyException and not expected_exit_seen:
                    raise AssertionError("{} not raised".format(raises))
        if stderr:
            return output.getvalue(), error.getvalue()
        else:
            return output.getvalue()
    finally:
        sys.argv = old_args
        if old_main is None:
            sys.modules.pop('__main__', None)
        else:
            sys.modules['__main__'] = old_main


def import_module(module):
    """Import ``modules/<module>.py`` from the directory next to this file.

    The path is anchored on ``DIR_PATH`` (like ``run_program``) rather than
    the current working directory, so it works no matter where the tests
    are started from.

    Args:
        module: Module name without the ``.py`` extension.

    Returns:
        A freshly executed module object, also registered in
        ``sys.modules`` under ``module``.

    Raises:
        FileNotFoundError: If the source file does not exist.
        Any exception raised while executing the module's code; in that
        case the previous ``sys.modules`` entry (if any) is put back.
    """
    # Local import: load_module() is deprecated, and the replacement API
    # lives in importlib.util, which the file does not import at the top.
    import importlib.util

    path = DIR_PATH / 'modules' / '{module}.py'.format(module=module)
    loader = SourceFileLoader(module, str(path))
    spec = importlib.util.spec_from_loader(module, loader)
    loaded = importlib.util.module_from_spec(spec)
    previous = sys.modules.get(module)
    sys.modules[module] = loaded
    try:
        loader.exec_module(loaded)
    except BaseException:
        # Do not leave a half-initialised module behind in sys.modules.
        if previous is None:
            sys.modules.pop(module, None)
        else:
            sys.modules[module] = previous
        raise
    return loaded




@contextmanager
def make_file(contents=None):
    """Context manager yielding the path of a temporary text file.

    The file is created (and closed) before the path is yielded, and is
    deleted when the context exits.  It is also deleted if writing
    ``contents`` fails, so no stray temporary file is left behind.

    Args:
        contents: Optional text to write to the file.  Falsy values leave
            the file empty.

    Yields:
        The file name (``str``) of the temporary file.
    """
    f = NamedTemporaryFile(mode='wt', delete=False)
    try:
        # Close the handle before yielding so callers can reopen the path.
        with f:
            if contents:
                f.write(contents)
        yield f.name
    finally:
        Path(f.name).unlink(missing_ok=True)


@contextmanager
def capture_stdin(data):
    """Temporarily replace ``sys.stdin`` with a buffer holding ``data``.

    The replacement stream is rewound to its start and yielded; the
    previous ``sys.stdin`` is put back when the context exits.
    """
    original_stdin = sys.stdin
    fake_stdin = StringIO()
    sys.stdin = fake_stdin
    try:
        fake_stdin.write(data)
        fake_stdin.seek(0)
        yield fake_stdin
    finally:
        sys.stdin = original_stdin


def error_message():
    """Print a hint that this script must be run through ``test.py``."""
    script_name = sys.argv[0]
    lines = (
        "Cannot run {} from the command-line.".format(script_name),
        "",
        "Run python test.py <your_exercise_name> instead",
    )
    for line in lines:
        print(line)


class BaseTestCase(unittest.TestCase):

    """Base class for tests of plain functions (not whole scripts).

    While a test runs, ``builtins.input`` is replaced by a stub that
    raises ``SystemError``, so code under test cannot prompt the user.
    """

    def setUp(self):
        """Install the ``input`` stub for the duration of the test."""
        message = (
            "This function should not use input()."
            " Use function arguments instead."
        )

        def fake_input(prompt=None):
            raise SystemError(message)

        patcher = patch("builtins.input", fake_input)
        self.input_patch = patcher
        patcher.start()

    def tearDown(self):
        """Restore the real ``input`` function."""
        patcher = self.input_patch
        patcher.stop()


class ModuleTestCase(unittest.TestCase):

    """Base class for tests that exercise a module or program file.

    Subclasses set a ``module_path`` class attribute naming a file inside
    the ``modules`` directory next to this file.
    """

    @classmethod
    def setUpClass(cls):
        """Fail early when ``module_path`` is unset or the file is missing."""
        unset = object()
        program = getattr(cls, 'module_path', unset)
        if program is unset:
            raise NotImplementedError('Test needs "module_path" attribute')
        expected_file = DIR_PATH / 'modules' / program
        if expected_file.is_file():
            return
        message = (
            "You need to make a file called {program} in the "
            "modules subdirectory."
        ).format(program=program)
        raise ValueError(message)
