Source code hyperqueue/ffi/client.py

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import dataclasses
from typing import Dict, List, Optional, Sequence

from . import JobId, TaskId, ffi
from .protocol import JobDescription


class HqClientContext:
    """
    Opaque handle describing a live connection to a HyperQueue server.

    An object of this kind is produced by `ffi.connect_to_server` and must be
    handed back unchanged to every FFI function that expects a client context.
    NOTE(review): in this module the class only appears as a type annotation;
    the actual object comes from the FFI layer -- confirm it is never built here.
    """

    pass


@dataclasses.dataclass(frozen=True, eq=True)
class FailedTaskContext:
    """
    Immutable record holding the failure details of a single task.

    Instances are assembled by `ClientConnection.get_failed_tasks` from the
    per-task dictionaries returned by the FFI layer.
    """

    # Error description reported for the failed task.
    error: str
    # Presumably the task's working directory; None when not reported.
    cwd: Optional[str]
    # Stdout information for the task (presumably a path -- confirm); may be None.
    stdout: Optional[str]
    # Stderr information for the task (presumably a path -- confirm); may be None.
    stderr: Optional[str]


# Failure details indexed first by job ID and then by task ID; this is the
# shape returned by `ClientConnection.get_failed_tasks`.
TaskFailureMap = Dict[JobId, Dict[TaskId, FailedTaskContext]]


class ClientConnection:
    """
    Thin wrapper around a HyperQueue server connection made through the FFI layer.

    Every method forwards to the matching function in `ffi`, passing along the
    context obtained when the connection was opened.
    """

    def __init__(self, directory: Optional[str] = None):
        """
        Open a connection via `ffi.connect_to_server`.

        :param directory: Forwarded unchanged to `ffi.connect_to_server`.
            Presumably the server directory used to locate connection info,
            with `None` selecting a default -- confirm against the FFI.
        """
        self.ctx: HqClientContext = ffi.connect_to_server(directory)

    def submit_job(self, job_description: JobDescription) -> JobId:
        """
        Submit `job_description` through the FFI.

        :return: The job ID reported by `ffi.submit_job` for the submitted job.
        """
        return ffi.submit_job(self.ctx, job_description)

    def wait_for_jobs(self, job_ids: Sequence[JobId], callback) -> List[JobId]:
        """
        Block until the given jobs are finished.

        :param job_ids: IDs of the jobs to wait for.
        :param callback: Forwarded unchanged to `ffi.wait_for_jobs`; its expected
            signature is defined by the FFI layer (not visible here).
        :return: A list of job IDs, as returned by `ffi.wait_for_jobs`.
            NOTE(review): presumably the IDs of jobs that contain failed tasks
            (suitable for passing to `get_failed_tasks`) -- confirm against the FFI.
            This is a list of IDs, not a count of failed tasks.
        """
        return ffi.wait_for_jobs(self.ctx, job_ids, callback)

    def stop_server(self):
        """Ask the connected server to stop; returns whatever `ffi.stop_server` returns."""
        return ffi.stop_server(self.ctx)

    def get_failed_tasks(self, job_ids: Sequence[JobId]) -> TaskFailureMap:
        """
        Fetch failure details for the tasks of the given jobs.

        :param job_ids: IDs of the jobs to query.
        :return: Mapping of job ID -> task ID -> `FailedTaskContext`.
        :raises KeyError: If the FFI omits any of the keys "error", "cwd",
            "stdout" or "stderr" for a task.
        """
        jobs = ffi.get_failed_tasks(self.ctx, job_ids)
        return {
            job_id: {
                task_id: self._to_failed_task_context(data)
                for (task_id, data) in task_data.items()
            }
            for (job_id, task_data) in jobs.items()
        }

    @staticmethod
    def _to_failed_task_context(data) -> FailedTaskContext:
        """Convert one per-task dict returned by the FFI into a `FailedTaskContext`."""
        # TODO: use a class directly, keep in dicts or use TypedDict
        return FailedTaskContext(
            error=data["error"],
            cwd=data["cwd"],
            stdout=data["stdout"],
            stderr=data["stderr"],
        )