# Source code for secml.parallel.parfor

from multiprocessing import Pool, cpu_count
from joblib import Parallel, delayed


def parfor(task, processes, args):
    """Parallel For.

    Applies a function *task* to each argument in *args*, using a pool of
    concurrent processes.

    Parameters
    ----------
    task : function
        Function object that should process each element in args.
        It is sent to the worker processes by `Pool.map`, so it must be
        picklable (e.g. defined at module level).
    processes : int
        Maximum number of concurrent processes to be used in the pool.
        If higher than multiprocessing.cpu_count(), all processor's
        cores will be used.
    args : any
        Iterable object, where each element is an argument for task.

    Returns
    -------
    out : list
        List containing the output of task(arg) for each arg in args,
        in the same order as args.

    """
    # Don't try to spawn more processes than available CPUs
    num_cores = min(cpu_count(), processes)
    # The context manager shuts the pool down once `map` has returned every
    # result, so worker processes are released on each call instead of
    # lingering until the Pool object happens to be garbage-collected.
    with Pool(processes=num_cores) as pool:
        return pool.map(task, args)
def parfor2(task, n_reps, processes, *args):
    """Parallel For (repetition-based variant).

    Run `task` `n_reps` times on a pool of concurrent processes. On
    repetition ``i`` the call made is ``task(i, *args)``: the repetition
    index is always the first positional argument, followed by every
    element of `args`.

    Parameters
    ----------
    task : function
        Callable executed once per repetition; its first parameter
        receives the repetition index.
    n_reps : int
        How many times `task` is executed (indices 0 .. n_reps - 1).
    processes : int
        Upper bound on the number of concurrent worker processes. Values
        above `multiprocessing.cpu_count()` are capped to the number of
        available cores.
    *args : any, optional
        Extra positional arguments forwarded unchanged to every call of
        `task`.

    Returns
    -------
    out : list
        One entry per repetition, ordered (rep0, rep1, ..., rep{n_reps-1}).

    """
    # Cap the worker count at the number of available CPU cores
    n_jobs = processes if processes < cpu_count() else cpu_count()
    runner = Parallel(n_jobs=n_jobs, backend='multiprocessing')
    jobs = (delayed(task)(rep, *args) for rep in range(n_reps))
    return runner(jobs)
def _element_wise_power(idx, list_of_scalars):
    """Return every scalar in *list_of_scalars* raised to the power *idx*.

    Used by the demo below as a `parfor2` task: *idx* is the repetition
    index supplied by `parfor2`.

    This is defined at module level rather than inside the ``__main__``
    guard. Under the ``spawn`` start method, worker processes re-import
    this module without executing the guarded code, so a function defined
    there could not be unpickled by the workers.
    """
    print("Repetition {:} started...".format(idx))
    list_of_scalars_pow = [scalar ** idx for scalar in list_of_scalars]
    print("Repetition {:} ended...".format(idx))
    return list_of_scalars_pow


if __name__ == "__main__":
    from math import factorial

    arguments = range(10)
    res = [factorial(z) for z in arguments]
    parres = parfor(factorial, 2, arguments)
    # Demo self-check: the parallel result must match the serial one
    assert parres == res
    print(parres)

    parout = parfor2(_element_wise_power, 4, 2, list(range(10)))
    print(parout)