import logging
import multiprocessing as mp
import time

try:
    # The ABC aliases in ``collections`` were removed in Python 3.10;
    # ``collections.abc`` is the supported location.
    from collections.abc import Sequence
except ImportError:
    from collections import Sequence

from luna.util.progress_tracker import ProgressData, ProgressTracker
from luna.util.file import new_unique_filename

# Module-wide logger (the root logger, as obtained by ``getLogger()`` with
# no name).
logger = logging.getLogger()

# Upper bound for the number of worker processes: one less than the number
# of CPUs reported by ``multiprocessing``. It is floored at 1 because on a
# single-CPU host ``cpu_count() - 1`` is 0, and a job runner clamped to zero
# workers would wait forever on a queue nobody consumes.
MAX_NPROCS = max(mp.cpu_count() - 1, 1)

class Sentinel:
    """Custom sentinel to stop workers.

    Worker loops check incoming queue items with ``isinstance(item, Sentinel)``
    and leave their loop when one is found, so the class carries no state.
    """
class ArgsGenerator:
    """Custom generator that implements __len__().

    This class can be used in conjunction with
    :class:`~luna.util.progress_tracker.ProgressTracker` in cases where the
    tasks are obtained from generators. Note that
    :class:`~luna.util.progress_tracker.ProgressTracker` requires a
    pre-defined number of tasks to calculate the progress, therefore a
    standard generator cannot be used directly as it does not implement
    __len__().

    Then, with `ArgsGenerator`, one may take advantage of generators and
    :class:`~luna.util.progress_tracker.ProgressTracker` by explicitly
    providing the number of tasks that will be generated.

    Parameters
    ----------
    generator : generator
        The tasks generator.
    nargs : int
        The number of tasks that will be generated. It is reported as-is by
        ``len()``; it is not checked against what ``generator`` yields.
    """

    def __init__(self, generator, nargs):
        self.generator = generator
        self.nargs = nargs

    def __len__(self):
        # The declared number of tasks, not a count of the generator.
        return self.nargs

    def __iter__(self):
        # Delegate to the wrapped generator, yielding its items unchanged.
        for d in self.generator:
            yield d
class ParallelJobs:
    """Executes a set of tasks in parallel
    (:py:class:`~multiprocessing.JoinableQueue`) or sequentially.

    Parameters
    ----------
    nproc : int or None
        The number of CPUs to use. The default value is the
        ``maximum number of CPUs - 1`` (``MAX_NPROCS``).
        If ``nproc`` is None, 0, or 1, run the jobs sequentially.
        A value that is not a non-negative integer is replaced by
        ``MAX_NPROCS``, and larger integers are capped at ``MAX_NPROCS``.

    Attributes
    ----------
    nproc : int or None
        The number of worker processes to use, or None for sequential
        execution.
    progress_tracker : ProgressTracker
        A :class:`~luna.util.progress_tracker.ProgressTracker` object to
        track the tasks' progress. It is None until :meth:`run_jobs` is
        called.
    """

    # TODO: add option to Threads/Multiprocessing

    def __init__(self, nproc=MAX_NPROCS):
        if nproc is not None:
            # Use 'MAX_NPROCS' if a non-integer has been provided.
            if not isinstance(nproc, int) or nproc < 0:
                nproc = MAX_NPROCS
            elif nproc in (0, 1):
                nproc = None
            else:
                nproc = min(nproc, MAX_NPROCS)

            # 'MAX_NPROCS' can be 0 on a single-CPU host. Starting zero
            # consumers would make run_jobs() block forever on the job queue,
            # so fall back to sequential execution in that case.
            if nproc is not None and nproc < 1:
                nproc = None

        self.nproc = nproc
        self.progress_tracker = None

    def _exec_func(self, data, func):
        """Call ``func`` with ``data`` and measure how long the call takes.

        Parameters
        ----------
        data : dict, sequence, or any
            A dict is expanded as keyword arguments, a non-string sequence
            as positional arguments, and anything else (including ``str``
            and ``bytes``) is passed as a single argument.
        func : callable
            The function to execute.

        Returns
        -------
        output : any
            The value returned by ``func`` or None if it raised.
        exception : Exception or None
            The exception raised by ``func``, if any. It is logged and
            returned instead of being propagated.
        proc_time : float
            Elapsed wall-clock time in seconds.
        """
        start = time.time()
        output = None
        exception = None
        try:
            # Unpack args as a dictionary.
            if isinstance(data, dict):
                output = func(**data)
            # Unpack args as a list. Strings are sequences too, but
            # splitting them into characters is never the intent, so they
            # are handled as a single argument below.
            elif (isinstance(data, Sequence)
                    and not isinstance(data, (str, bytes))):
                output = func(*data)
            # Try unpack the function as a single arg.
            else:
                output = func(data)
        # Capture any errors while executing the above code.
        except Exception as e:
            logger.exception(e)
            exception = e
        proc_time = time.time() - start
        return output, exception, proc_time

    def _producer(self, args, job_queue):
        """Put every item of ``args`` into ``job_queue``."""
        for data in args:
            job_queue.put(data)

    def _consumer(self, func, job_queue, progress_queue, output_queue=None):
        """Worker loop: take tasks from ``job_queue`` and run ``func`` on them.

        The loop ends when a :class:`Sentinel` is read from ``job_queue``.
        For each task a ``ProgressData`` is put into ``progress_queue``;
        if ``output_queue`` is given and the output is not None, the pair
        ``(data, output)`` is put into ``output_queue``.
        """
        while True:
            data = job_queue.get()
            # If sentinel is found, break.
            if isinstance(data, Sentinel):
                break

            # Execute the provided function.
            output, exception, proc_time = self._exec_func(data, func)

            if output is not None and output_queue is not None:
                output_queue.put((data, output))

            # Update progress tracker.
            pd = ProgressData(input_data=data, output_data=output,
                              exception=exception, proc_time=proc_time,
                              func=func)
            progress_queue.put(pd)

            job_queue.task_done()

    def _saver(self, output_queue, output_file, proc_func=None,
               output_header=None):
        """Writer loop: save items read from ``output_queue`` to a file.

        ``output_file`` is opened for writing and, if provided,
        ``output_header`` is written first. Each queue item is expected to
        be the ``(data, output)`` pair produced by :meth:`_consumer`. When
        ``proc_func`` is given, it is executed on the item and its result is
        written; if that result is None (or ``proc_func`` is not given), the
        second element of the item is written instead. The loop ends when a
        :class:`Sentinel` is read.
        """
        with open(output_file, "w") as OUT:
            if output_header is not None:
                OUT.write(output_header.strip())
                OUT.write("\n")

            while True:
                data = output_queue.get()
                # If sentinel is found, break.
                if isinstance(data, Sentinel):
                    break

                line = None
                if proc_func is not None:
                    # Execute the provided function.
                    output, exception, proc_time = \
                        self._exec_func(data, proc_func)
                    line = output

                try:
                    # If no data is stored in line, try to access the output
                    # generated by the _consumer() function.
                    if line is None:
                        line = data[1]
                    OUT.write(str(line).strip())
                    OUT.write("\n")
                    # Flush per record so that what has been written is on
                    # disk even though this daemon process is never joined.
                    OUT.flush()
                except Exception as e:
                    logger.error("An error occurred while trying to save "
                                 "the output '%s'." % str(line))
                    logger.exception(e)

                output_queue.task_done()

    def _sequential(self, args, func, progress_queue):
        """Run ``func`` for each item of ``args`` in the current process."""
        # Run jobs sequentially.
        for data in args:
            # Execute provided function.
            output, exception, proc_time = self._exec_func(data, func)

            # Save data. 'func' is passed along so that sequential runs
            # report the same information as _consumer() does.
            pd = ProgressData(input_data=data, output_data=output,
                              exception=exception, proc_time=proc_time,
                              func=func)
            # Update progress tracker.
            progress_queue.put(pd)

    def run_jobs(self, args, consumer_func, output_file=None,
                 proc_output_func=None, output_header=None, job_name=None):
        """Run a set of tasks in parallel or sequentially according to
        the ``nproc``.

        Parameters
        ----------
        args : iterable of iterables, `ArgsGenerator`
            A sequence of arguments to be provided to the consumer function
            ``consumer_func``. It must support ``len()``.
        consumer_func : function
            The function that will be executed for each set of arguments in
            ``args``.
        output_file : str, optional
            Save outputs to this file. If ``proc_output_func`` is not
            provided, it tries to save a stringified version of each output
            data. Otherwise, it executes ``proc_output_func`` first and its
            output will be printed to the output file instead.
            Note: if ``proc_output_func`` is provided but not
            ``output_file``, a new random unique filename will be generated
            and the file will be saved in the current directory.
        proc_output_func : function, optional
            Post-processing function that is executed for each output data
            produced by ``consumer_func``.
        output_header : str, optional
            A header for the output file.
        job_name : str, optional
            A name to identify the job.

        Returns
        -------
         : :class:`~luna.util.progress_tracker.ProgressResult`
        """
        if proc_output_func is not None and output_file is None:
            output_file = new_unique_filename(".") + ".output"
            logger.warning("No output file was defined. So, it will try to "
                           "save results at '%s'." % output_file)
        elif output_file is not None:
            logger.warning("The output file '%s' was defined. So, it will "
                           "try to save results at it." % output_file)

        # Queue for progress tracker.
        progress_queue = mp.JoinableQueue(maxsize=1)
        # Progress tracker
        self.progress_tracker = ProgressTracker(len(args), progress_queue,
                                                job_name)
        self.progress_tracker.start()

        # Initialize a new progress bar (display a 0% progress).
        progress_queue.put(None)

        if self.nproc is not None:
            job_queue = mp.JoinableQueue(maxsize=self.nproc)

            output_queue = None
            if output_file is not None:
                output_queue = mp.JoinableQueue()

            for i in range(self.nproc):
                p = mp.Process(name="ConsumerProcess-%d" % i,
                               target=self._consumer,
                               args=(consumer_func, job_queue,
                                     progress_queue, output_queue,))
                p.daemon = True
                p.start()

            if output_file is not None:
                # Single writer process. The index in its name is the one of
                # the last consumer, spelled out instead of relying on the
                # loop variable leaking out of the loop above.
                o = mp.Process(name="WriterProcess-%d" % (self.nproc - 1),
                               target=self._saver,
                               args=(output_queue, output_file,
                                     proc_output_func, output_header,))
                o.daemon = True
                o.start()

            # Produce tasks to consumers.
            self._producer(args, job_queue)

            # Sentinels to stop consumers.
            sentinel = Sentinel()

            # Wait until every task is done, then add one sentinel per
            # consumer to stop them.
            job_queue.join()
            for _ in range(self.nproc):
                job_queue.put(sentinel)

            if output_file is not None:
                output_queue.join()
                output_queue.put(sentinel)
        else:
            # NOTE(review): in sequential mode 'output_file',
            # 'proc_output_func' and 'output_header' are not used, so no
            # output file is written even though the warning above says so.
            self._sequential(args, consumer_func, progress_queue)

        # Finish the progress tracker.
        self.progress_tracker.end()

        return self.progress_tracker.results