from abc import ABCMeta, abstractmethod
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Awaitable, Generic, Generator, Iterable, \
  List, Tuple, TypeVar, cast

###

# Result type of an action or awaitable.
T = TypeVar('T')
# Second general-purpose type variable; not referenced in the visible part of this module.
R = TypeVar('R')


###

# Sentinel for "no result stored yet". Promise compares against it by identity,
# so any real value (including None) can be stored as a result.
UNFULFILLED = object()


@dataclass
class Promise(Generic[T]):
    """Write-once slot for the result of a single action.

    ``Action.__await__`` yields a promise to whoever is driving the
    coroutine; the driver stores the action's result with ``complete`` and
    the coroutine reads it back through ``value`` once it is resumed.
    """

    # The action whose result this promise will hold.
    action: 'Action[T]'
    # Stays UNFULFILLED until ``complete`` stores a result.
    _result: Any = UNFULFILLED

    def complete(self, value: T) -> None:
        """Store ``value`` as the result of the action.

        Raises:
            RuntimeError: if a result has already been stored.
        """
        if self._result is not UNFULFILLED:
            raise RuntimeError(
                f'promise for {self.action!r} was already completed'
            )
        self._result = value

    @property
    def value(self) -> T:
        """The stored result.

        Raises:
            RuntimeError: if no result has been stored yet.
        """
        if self._result is UNFULFILLED:
            raise RuntimeError(
                f'promise for {self.action!r} has not been completed yet'
            )
        return cast(T, self._result)


###

class Action(Awaitable[T]):
    """Base class for effect descriptions that a coroutine can ``await``.

    Awaiting an action does not perform it. The action is handed, wrapped in
    a ``Promise``, to whatever is driving the coroutine, and the awaiting
    code receives the value stored in that promise when it is resumed.
    """

    def __await__(self) -> Generator[Any, None, T]:
        """Yield a fresh promise for this action, then return its value."""
        promise = Promise(self)
        # Suspend here; the driver is expected to complete the promise
        # before resuming, otherwise reading ``value`` below raises.
        yield promise
        return promise.value

    @abstractmethod
    def match(self, interpreter: 'ActionInterpreter') -> T:
        """Dispatch this action to the interpreter method that handles it."""
        ...


@dataclass
class AwaitableAction(Action[T]):
    """Adapter that lets an arbitrary awaitable be used where an ``Action`` is expected."""

    # The wrapped awaitable; awaiting this action awaits it instead.
    coroutine: Awaitable[T]

    def __await__(self) -> Generator[Any, None, T]:
        """Return the wrapped awaitable's own ``__await__`` iterator."""
        wrapped = self.coroutine
        return wrapped.__await__()

    def match(self, interpreter: 'ActionInterpreter') -> T:
        """Always raise ``NotImplementedError``.

        ``__await__`` above never yields a promise for this wrapper itself,
        so there is no interpreter method to dispatch to.
        """
        raise NotImplementedError


@dataclass
class LogAction(Action[None]):
    """Action describing a request to log a message."""

    # Text to be logged.
    message: str

    def match(self, interpreter: 'ActionInterpreter') -> None:
        """Hand this action to ``interpreter.log`` and return its result."""
        outcome = interpreter.log(self)
        return outcome


@dataclass
class ParallelAction(Action[List[T]]):
    """Action grouping several child actions whose results are returned as a list."""

    # The child actions to be run by the interpreter.
    actions: Iterable[Action[T]]

    def match(self, interpreter: 'ActionInterpreter') -> List[T]:
        """Hand this action to ``interpreter.parallel`` and return its result."""
        results = interpreter.parallel(self)
        return results


# Readable aliases for the two strings that make up a CommandAction result.
StandardOut = str
StandardErr = str


@dataclass
class CommandAction(Action[Tuple[StandardOut, StandardErr]]):
    """Action describing a command invocation by its argument strings."""

    # The command's arguments, in order.
    arguments: Tuple[str, ...]

    def match(self, interpreter: 'ActionInterpreter') -> Tuple[StandardOut, StandardErr]:
        """Hand this action to ``interpreter.command`` and return its result."""
        output = interpreter.command(self)
        return output


###

def log(message: str) -> Action[None]:
    """Build a ``LogAction`` for ``message``; await the result to have it interpreted."""
    action = LogAction(message)
    return action


def parallel(*args: Awaitable[T]) -> Action[List[T]]:
    """Build a ``ParallelAction`` over the given awaitables.

    Arguments that are already ``Action`` instances are used as-is; any other
    awaitable is wrapped in an ``AwaitableAction``.

    Args:
        *args: The awaitables to group, in the order their results should
            be reported.

    Returns:
        A ``ParallelAction`` whose ``actions`` field is a tuple.
    """
    # Materialise the children. A generator expression can be traversed only
    # once, so an interpreter that walked ``actions`` a second time would
    # silently see nothing; it also has an opaque repr and compares by
    # identity, which undermines the dataclass's generated repr and ==.
    actions: Tuple[Action[T], ...] = tuple(
        arg if isinstance(arg, Action) else AwaitableAction(arg)
        for arg in args
    )
    return ParallelAction(actions)


def command(*args: str) -> Action[Tuple[StandardOut, StandardErr]]:
    """Build a ``CommandAction`` whose arguments are the given strings, in order."""
    action = CommandAction(args)
    return action


###

class ActionInterpreter(metaclass=ABCMeta):
    """Abstract interpreter that gives meaning to ``Action`` values.

    Subclasses implement one handler per action type; ``run`` drives an
    awaitable and routes every action it awaits to the matching handler.
    """

    @abstractmethod
    def log(self, action: LogAction) -> None:
        """Handle a ``LogAction``."""

    @abstractmethod
    def parallel(self, action: ParallelAction[T]) -> List[T]:
        """Handle a ``ParallelAction`` and return the list of its results."""

    @abstractmethod
    def command(self, action: CommandAction) -> Tuple[StandardOut, StandardErr]:
        """Handle a ``CommandAction`` and return its ``(stdout, stderr)`` pair."""

    def run(self, awaitable: Awaitable[T]) -> T:
        """Drive ``awaitable`` to completion and return its result.

        Each value the awaitable yields is treated as a ``Promise``: its
        action is dispatched to this interpreter through ``Action.match`` and
        the outcome is stored in the promise before the awaitable is resumed.

        Args:
            awaitable: The coroutine or other awaitable to execute.

        Returns:
            The value the awaitable returns.

        Raises:
            Any exception raised by a handler or by the awaitable itself
            propagates to the caller unchanged.
        """
        gen = cast(Generator[Promise[Any], None, T], awaitable.__await__())
        while True:
            # Only the resume step is guarded: a StopIteration escaping from
            # a handler must not be mistaken for the awaitable finishing.
            try:
                promise = gen.send(None)
            except StopIteration as stop:
                return cast(T, stop.value)
            promise.complete(promise.action.match(self))


class DebugInterpreter(ActionInterpreter):
    """Interpreter that prints each action; commands are not executed."""

    def log(self, action: LogAction) -> None:
        """Print the action object itself."""
        print(action)

    def parallel(self, parent: ParallelAction[T]) -> List[T]:
        """Run every child to completion, one after another, and collect the results."""
        # hint: fork interpreter
        return [DebugInterpreter().run(child) for child in parent.actions]

    def command(self, action: CommandAction) -> Tuple[StandardOut, StandardErr]:
        """Print the argument tuple and return placeholder output strings."""
        print(f'running: {action.arguments!s}')
        placeholder = ('(stdout)', '(stderr)')
        return placeholder


if TYPE_CHECKING:
    # concreteness proof: TYPE_CHECKING is False at runtime, so this call
    # never executes; it exists so a static type checker can report an error
    # if DebugInterpreter leaves any abstract method unimplemented.
    DebugInterpreter()


###

async def what(x: int) -> None:
    """Sample coroutine: log ``str(x)``."""
    text = str(x)
    await log(text)


async def hello(name: str) -> int:
    """Sample workflow exercising every action type.

    Awaits a command, logs its standard output, awaits a nested coroutine,
    logs a greeting for ``name``, awaits two ``what`` coroutines through
    ``parallel``, and returns 123.
    """
    stdout, _stderr = await command('git', 'checkout', 'woof')
    await log('output: ' + stdout)

    await what(1)

    greeting = 'hello ' + name
    await log(greeting)

    children = (what(2), what(3))
    await parallel(*children)

    return 123


###

if __name__ == '__main__':
    # Demo entry point: interpret the sample ``hello`` coroutine with the
    # debug interpreter and print the value it returns.
    print(DebugInterpreter().run(hello('what')))

###
