# Source code for sparklanes._framework.lane

"""Lane and Branch classes. TODO: Better logging"""
from importlib import import_module
from inspect import isclass

import yaml
from schema import SchemaError

from six import string_types

from sparklanes._framework.errors import LaneSchemaError, LaneImportError
from sparklanes._framework.validation import validate_schema

from .env import INTERNAL_LOGGER_NAME
from .errors import LaneExecutionError
from .log import make_default_logger
from .task import LaneTask, LaneTaskThread
from .validation import validate_params

class Lane(object):
    """Used to build and run data processing lanes (i.e. pipelines).
    Public methods are chainable."""

    def __init__(self, name='UnnamedLane', run_parallel=False):
        """
        Parameters
        ----------
        name : str
            Custom name of the lane
        run_parallel : bool
            Indicates, whether the tasks in a Lane shall be executed in parallel.
            Does not affect branches inside the lane (`run_parallel` must be indicated in the
            branches themselves)

        Raises
        ------
        TypeError
            If `name` is not a string
        """
        if not isinstance(name, string_types):
            raise TypeError('`name` must be a string')

        self.name = name
        self.run_parallel = run_parallel

        # Entries are either `Branch` instances or dicts with the keys
        # 'cls_or_branch', 'args' and 'kwargs' (see `add`).
        self.tasks = []

    def __str__(self):
        """Generates a readable string showing the tasks and branches inside the lane.
        Nested branches are indented one level deeper than their parent."""
        task_str = '=' * 80 + '\n'

        def generate_str(lane_or_branch, prefix='\t', out=''):
            """Recursive string generation"""
            out += prefix + lane_or_branch.name
            if lane_or_branch.run_parallel:
                out += ' (parallel)'
            out += '\n'

            for task in lane_or_branch.tasks:
                if isinstance(task, Branch):
                    # `prefix[0]` is the indentation character; add one more level
                    out += generate_str(task, prefix + prefix[0])
                elif isinstance(task['cls_or_branch'], Branch):
                    out += generate_str(task['cls_or_branch'], prefix + prefix[0])
                else:
                    out += prefix + ' >' + task['cls_or_branch'].__name__ + '\n'

            return out

        task_str += generate_str(self) + '=' * 80

        return task_str

    def __validate_task(self, cls, entry_mtd_name, args, kwargs):
        """Checks if a class is a task, i.e. if it has been decorated with `sparklanes.Task`, and if
        the supplied args/kwargs match the signature of the task's entry method.

        Parameters
        ----------
        cls : LaneTask
        entry_mtd_name : str
            Name of the method, which is called when the task is run
        args : list
        kwargs : dict

        Raises
        ------
        TypeError
            If `cls` is not a class, or not a subclass of `LaneTask`
        """
        if not isclass(cls) or not issubclass(cls, LaneTask):
            raise TypeError('Tried to add non-Task `%s` to a Lane. Are you sure the task was '
                            'decorated with `sparklanes.Task`?' % str(cls))

        validate_params(cls, entry_mtd_name, *args, **kwargs)

    def add(self, cls_or_branch, *args, **kwargs):
        """Adds a task or branch to the lane.

        Parameters
        ----------
        cls_or_branch : Class
        *args
            Variable length argument list to be passed to `cls_or_branch` during instantiation
        **kwargs
            Variable length keyword arguments to be passed to `cls_or_branch` during instantiation

        Returns
        -------
        self: Returns `self` to allow method chaining
        """
        if isinstance(cls_or_branch, Branch):
            self.tasks.append(cls_or_branch)  # Add branch with already validated tasks
        else:
            # Validate
            self.__validate_task(cls_or_branch, '__init__', args, kwargs)
            # Append
            self.tasks.append({'cls_or_branch': cls_or_branch, 'args': args, 'kwargs': kwargs})

        return self

    def run(self):
        """Executes the tasks in the lane in the order in which they have been added, unless
        `self.run_parallel` is True, then a thread is spawned for each task and executed in
        parallel (note that task threads are still spawned in the order in which they were added).

        Returns
        -------
        self: Returns `self` to allow method chaining

        Raises
        ------
        LaneExecutionError
            If the lane contains no tasks
        """
        logger = make_default_logger(INTERNAL_LOGGER_NAME)
        logger.info('\n%s\nExecuting `%s`\n%s\n', '-'*80, self.name, '-'*80)
        logger.info('\n%s', str(self))

        threads = []

        if not self.tasks:
            raise LaneExecutionError('No tasks to execute!')

        for task_def_or_branch in self.tasks:
            if isinstance(task_def_or_branch, Branch):
                task_def_or_branch.run()
            elif isinstance(task_def_or_branch['cls_or_branch'], Branch):  # Nested Branch
                task_def_or_branch['cls_or_branch'].run()
            else:
                task = task_def_or_branch['cls_or_branch'](*task_def_or_branch['args'],
                                                           **task_def_or_branch['kwargs'])
                if self.run_parallel:
                    # Only collect here; threads are started after all tasks are instantiated
                    threads.append(LaneTaskThread(task))
                else:
                    task()

        if threads:
            for thread in threads:
                thread.start()
            for thread in threads:
                thread.join()

        logger.info('\n%s\nFinished executing `%s`\n%s', '-'*80, self.name, '-'*80)

        return self
class Branch(Lane, object):
    """Branches can be used to split task lanes into branches, which is e.g. useful if part of the
    data processing pipeline should be executed in parallel, while other parts should be run in
    subsequent order."""

    def __init__(self, name='UnnamedBranch', run_parallel=False):
        """
        Parameters
        ----------
        name : str
            Custom name of the branch
        run_parallel : bool
            Indicates if the task in the branch shall be executed in parallel
        """
        super(Branch, self).__init__(name=name, run_parallel=run_parallel)


def build_lane_from_yaml(path):
    """Builds a `sparklanes.Lane` object from a YAML definition file.

    Parameters
    ----------
    path: str
        Path to the YAML definition file

    Returns
    -------
    Lane
        Lane, built according to definition in YAML file

    Raises
    ------
    LaneSchemaError
        If the YAML definition does not pass schema validation
    LaneImportError
        If a task's class path has no module part, or its module/class cannot be found
    """
    # Open
    with open(path, 'rb') as yaml_definition:
        # NOTE(review): `yaml.load` without an explicit `Loader` can construct arbitrary Python
        # objects on older PyYAML releases and is rejected by PyYAML >= 6. Consider
        # `yaml.safe_load` if lane definitions never rely on Python-specific YAML tags.
        definition = yaml.load(yaml_definition)

    # Validate schema
    try:
        validate_schema(definition)
    except SchemaError as exc:
        raise LaneSchemaError(**exc.__dict__)

    def build(lb_def, branch=False):
        """Function to recursively build the `sparklanes.Lane` object from a YAML definition"""
        # Only forward the constructor arguments that are present in the definition
        init_kwargs = {k: lb_def[k] for k in ('run_parallel', 'name') if k in lb_def}
        lane_or_branch = Lane(**init_kwargs) if not branch else Branch(**init_kwargs)

        for task in lb_def['tasks']:
            if 'branch' in task:
                branch_def = task['branch']
                lane_or_branch.add(build(branch_def, True))
            else:
                # Split 'pkg.module.ClassName' into module path and class name
                sep = task['class'].rfind('.')
                if sep == -1:
                    raise LaneImportError('Class must include its parent module')
                mdl = task['class'][:sep]
                cls_ = task['class'][sep + 1:]

                try:
                    cls = getattr(import_module(mdl), cls_)
                except ImportError:
                    raise LaneImportError('Could not find module %s' % mdl)
                except AttributeError:
                    raise LaneImportError('Could not find class %s' % cls_)

                args = task['args'] if 'args' in task else []
                # A single scalar argument is allowed; wrap it so it can be unpacked
                args = [args] if not isinstance(args, list) else args
                kwargs = task['kwargs'] if 'kwargs' in task else {}
                lane_or_branch.add(cls, *args, **kwargs)

        return lane_or_branch

    return build(definition['lane'])