Source code hyperqueue/task/program.py

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
from typing import Dict, List, Optional, Sequence, Union

from ..common import GenericPath
from ..ffi import TaskId
from ..ffi.protocol import ResourceRequest, TaskDescription
from ..output import Output, gather_outputs
from ..validation import ValidationException, validate_args
from .task import EnvType, Task

# Accepted forms of program arguments: either a list of strings or a single
# string (`to_arg_list` below wraps a single string into a one-element list).
ProgramArgs = Union[List[str], str]


class ExternalProgram(Task):
    """
    Task that represents the execution of an executable binary.

    Besides the state initialised by `Task.__init__`, an instance carries:

    - `args`: the program arguments after normalisation by `to_arg_list`
      and validation by `validate_args`.
    - `task_dir`: flag that is forwarded unchanged to `TaskDescription`.
    - `stdin`: `None` or the standard input payload as `bytes`.
    - `outputs`: mapping from output name to the `Output` objects gathered
      from `args` and `env` (see `get_task_outputs`).
    """

    def __init__(
        self,
        task_id: TaskId,
        *,
        args: ProgramArgs,
        env: Optional[EnvType] = None,
        cwd: Optional[GenericPath] = None,
        stdout: Optional[GenericPath] = None,
        stderr: Optional[GenericPath] = None,
        stdin: Optional[Union[str, bytes]] = None,
        name: Optional[str] = None,
        dependencies: Sequence[Task] = (),
        task_dir: bool = False,
        priority: int = 0,
        resources: Optional[ResourceRequest] = None,
    ):
        """
        Create a task that runs an external program.

        :param task_id: Identifier of the task, forwarded to `Task.__init__`.
        :param args: Program arguments; either a list or a single string,
            which is treated as one argument (it is not split).
        :param env: Forwarded to `Task.__init__`.
        :param cwd: Forwarded to `Task.__init__`.
        :param stdout: Forwarded to `Task.__init__`.
        :param stderr: Forwarded to `Task.__init__`.
        :param stdin: Data for the standard input of the program. A `str` is
            encoded with `str.encode()` (UTF-8 by default); `bytes` and `None`
            are stored as they are.
        :param name: Forwarded to `Task.__init__`.
        :param dependencies: Tasks this task depends on, forwarded to
            `Task.__init__`.
        :param task_dir: Stored on the instance and passed to
            `TaskDescription` when the task is built.
        :param priority: Forwarded to `Task.__init__`.
        :param resources: Resource request forwarded to `Task.__init__`.
            Defaults to `None`, matching its `Optional` annotation.
        :raises Exception: If `stdin` is not `str`, `bytes` or `None`.
        :raises ValidationException: If two outputs share the same name
            (raised by `get_task_outputs`).
        """
        super().__init__(
            task_id,
            dependencies,
            priority,
            resources,
            env=env,
            cwd=cwd,
            stdout=stdout,
            stderr=stderr,
            name=name,
        )
        # Normalise first so that validation always sees the list form.
        args = to_arg_list(args)
        validate_args(args)
        self.args = args
        self.task_dir = task_dir

        if stdin is None or isinstance(stdin, bytes):
            self.stdin = stdin
        elif isinstance(stdin, str):
            # `TaskDescription` is handed bytes (or None), never text.
            self.stdin = stdin.encode()
        else:
            raise Exception("stdin has to be str, bytes, or None")

        # Must run last: it reads `self.args` and `self.env`.
        self.outputs = get_task_outputs(self)

    def _build(self, client) -> TaskDescription:
        """
        Create the `TaskDescription` corresponding to this task.

        Dependencies are passed as the task ids of the dependent tasks.
        `client` is not used by this implementation.
        """
        depends_on = [dependency.task_id for dependency in self.dependencies]
        return TaskDescription(
            id=self.task_id,
            args=self.args,
            env=self.env,
            stdout=self.stdout,
            stderr=self.stderr,
            stdin=self.stdin,
            cwd=self.cwd,
            dependencies=depends_on,
            task_dir=self.task_dir,
            priority=self.priority,
            resource_request=self.resources,
        )

    def __getitem__(self, key: str) -> Output:
        """
        Return the output registered under the name `key`.

        :raises Exception: If this task has no output with that name.
        """
        if key not in self.outputs:
            raise Exception(f"Output `{key}` not found in {self}")
        return self.outputs[key]

    def __repr__(self) -> str:
        """
        Return a debugging representation listing args, env, cwd and outputs.
        """
        return f"Task(args={self.args}, env={self.env}, cwd={self.cwd}, outputs={self.outputs})"

def to_arg_list(args: ProgramArgs) -> List[str]:
    """
    Normalise program arguments into a list.

    - A single string becomes a one-element list; it is treated as one
      argument and is not split.
    - A list or tuple is copied into a new list, so that later mutation of
      the caller's sequence cannot change the returned arguments.
    - Any other value is returned unchanged, leaving it to the caller's
      validation to reject it.

    :param args: Program arguments in any of the accepted forms.
    :return: The arguments as a list (or `args` itself if it is neither a
        string, a list nor a tuple).
    """
    if isinstance(args, str):
        return [args]
    if isinstance(args, (list, tuple)):
        # Defensive copy: do not alias the caller's sequence.
        return list(args)
    return args


def get_task_outputs(task: ExternalProgram) -> Dict[str, Output]:
    """
    Build a name -> output mapping from the outputs found in a task.

    Outputs are gathered with `gather_outputs` from the arguments of the task
    first and from its environment second.

    :param task: Task whose `args` and `env` are searched for outputs.
    :return: Dictionary keyed by output name.
    :raises ValidationException: If an output name occurs more than once.
    """
    # TODO: outputs in cwd
    # TODO: multiple outputs with the same name, but different parameters
    by_name = {}

    for found in gather_outputs(task.args) + gather_outputs(task.env):
        key = found.name
        if key in by_name:
            raise ValidationException(
                f"Output `{key}` has been defined multiple times"
            )
        by_name[key] = found
    return by_name