# Source code: hyperqueue/task/program.py

from typing import Dict, List, Optional, Sequence, Union

from ..common import GenericPath
from ..ffi import TaskId
from ..ffi.protocol import ResourceRequest, TaskDescription
from ..output import Output, Stdio, gather_outputs
from ..validation import ValidationException, validate_args
from .task import EnvType, Task, _make_ffi_requests

# Accepted forms of a program's command line: either a list of strings or a
# single string (`to_arg_list` wraps a single string into a one-element list).
ProgramArgs = Union[List[str], str]


class ExternalProgram(Task):
    """
    Task that represents the execution of an executable binary.

    Besides the attributes set up by the `Task` base class, an instance keeps:

    - `args`: the program arguments, normalized to a list,
    - `task_dir`: the flag passed through to the task description,
    - `stdin`: the standard input as `bytes`, or `None`,
    - `outputs`: the outputs found in the arguments and the environment,
      indexed by name (also reachable through `task[name]`).
    """

    def __init__(
        self,
        task_id: TaskId,
        *,
        # `to_arg_list` below accepts a single string as well as a list, so the
        # annotation has to allow both forms.
        args: ProgramArgs,
        env: Optional[EnvType] = None,
        cwd: Optional[GenericPath] = None,
        stdout: Optional[Stdio] = None,
        stderr: Optional[Stdio] = None,
        stdin: Optional[Union[str, bytes]] = None,
        name: Optional[str] = None,
        dependencies: Sequence[Task] = (),
        task_dir: bool = False,
        priority: int = 0,
        resources: Optional[ResourceRequest],
        crash_limit: Optional[int] = None,
    ):
        """
        Create a task that runs an external program.

        All parameters except `task_id` are keyword-only.

        :param task_id: Forwarded to the `Task` base class.
        :param args: Program arguments. A single string is wrapped into a
            one-element list; the result is checked with `validate_args`.
        :param env: Forwarded to the `Task` base class.
        :param cwd: Forwarded to the `Task` base class.
        :param stdout: Forwarded to the `Task` base class.
        :param stderr: Forwarded to the `Task` base class.
        :param stdin: Standard input of the program. A `str` is converted with
            `str.encode()`, `bytes` are stored as they are, `None` is kept.
        :param name: Forwarded to the `Task` base class.
        :param dependencies: Forwarded to the `Task` base class.
        :param task_dir: Stored on the task and passed to the task description.
        :param priority: Forwarded to the `Task` base class.
        :param resources: Forwarded to the `Task` base class. It has no default,
            so it must always be passed, even when the value is `None`.
        :param crash_limit: Forwarded to the `Task` base class.
        :raises Exception: If `stdin` is neither `str`, `bytes` nor `None`.
        :raises ValidationException: If two outputs share the same name.
            `validate_args` may raise for invalid arguments as well.
        """
        super().__init__(
            task_id,
            dependencies,
            priority,
            resources,
            env=env,
            cwd=cwd,
            stdout=stdout,
            stderr=stderr,
            name=name,
            crash_limit=crash_limit,
        )
        args = to_arg_list(args)
        validate_args(args)
        self.args = args
        self.task_dir = task_dir

        # Keep stdin either as bytes or as None; text is encoded up front.
        if stdin is None or isinstance(stdin, bytes):
            self.stdin = stdin
        elif isinstance(stdin, str):
            self.stdin = stdin.encode()
        else:
            raise Exception("stdin has to be str, bytes, or None")

        # Reads `self.args` and `self.env`, so it has to run after both the
        # base-class initialization and the argument normalization above.
        self.outputs = get_task_outputs(self)

    def _build(self, client):
        """
        Build the `TaskDescription` of this task.

        :param client: Not used by this implementation.
        :return: Description carrying the task's id, arguments, environment,
            stdio settings, working directory, ids of its dependencies,
            task directory flag, priority, resource requests and crash limit.
        """
        # Dependencies are referenced by their task ids.
        depends_on = [dependency.task_id for dependency in self.dependencies]
        return TaskDescription(
            id=self.task_id,
            args=self.args,
            env=self.env,
            stdout=self.stdout,
            stderr=self.stderr,
            stdin=self.stdin,
            cwd=self.cwd,
            dependencies=depends_on,
            task_dir=self.task_dir,
            priority=self.priority,
            resource_request=_make_ffi_requests(self.resources),
            crash_limit=self.crash_limit,
        )

    def __getitem__(self, key: str):
        """
        Return the output of this task registered under the name `key`.

        :raises Exception: If the task has no output with that name.
        """
        if key not in self.outputs:
            raise Exception(f"Output `{key}` not found in {self}")
        return self.outputs[key]

    def __repr__(self):
        """Return a short description with the arguments, env, cwd and outputs."""
        return f"Task(args={self.args}, env={self.env}, cwd={self.cwd}, outputs={self.outputs})"


def to_arg_list(args: ProgramArgs) -> List[str]:
    """
    Normalize program arguments into their list form.

    A single string becomes a one-element list (the string is not split in any
    way). Any other value is handed back unchanged, as the very same object.
    """
    return [args] if isinstance(args, str) else args


def get_task_outputs(task: ExternalProgram) -> Dict[str, Output]:
    """
    Build a name-indexed map of the outputs used by `task`.

    The outputs are whatever `gather_outputs` returns for the task's arguments,
    followed by whatever it returns for the task's environment.

    :raises ValidationException: If two gathered outputs have the same name.
    """
    # TODO: outputs in cwd
    # TODO: multiple outputs with the same name, but different parameters
    by_name: Dict[str, Output] = {}

    for found in gather_outputs(task.args) + gather_outputs(task.env):
        name = found.name
        # Names must be unique, otherwise a lookup by name would be ambiguous.
        if name in by_name:
            raise ValidationException(f"Output `{name}` has been defined multiple times")
        by_name[name] = found
    return by_name