|
import multiprocessing |
|
import queue as Queue |
|
import threading |
|
import time |
|
|
|
|
|
class SubprocessGenerator(object): |
|
|
|
@staticmethod |
|
def launch_thread(generator): |
|
generator._start() |
|
|
|
@staticmethod |
|
def start_in_parallel( generator_list ): |
|
""" |
|
Start list of generators in parallel |
|
""" |
|
for generator in generator_list: |
|
thread = threading.Thread(target=SubprocessGenerator.launch_thread, args=(generator,) ) |
|
thread.daemon = True |
|
thread.start() |
|
|
|
while not all ([generator._is_started() for generator in generator_list]): |
|
time.sleep(0.005) |
|
|
|
def __init__(self, generator_func, user_param=None, prefetch=2, start_now=True): |
|
super().__init__() |
|
self.prefetch = prefetch |
|
self.generator_func = generator_func |
|
self.user_param = user_param |
|
self.sc_queue = multiprocessing.Queue() |
|
self.cs_queue = multiprocessing.Queue() |
|
self.p = None |
|
if start_now: |
|
self._start() |
|
|
|
def _start(self): |
|
if self.p == None: |
|
user_param = self.user_param |
|
self.user_param = None |
|
p = multiprocessing.Process(target=self.process_func, args=(user_param,) ) |
|
p.daemon = True |
|
p.start() |
|
self.p = p |
|
|
|
def _is_started(self): |
|
return self.p is not None |
|
|
|
def process_func(self, user_param): |
|
self.generator_func = self.generator_func(user_param) |
|
while True: |
|
while self.prefetch > -1: |
|
try: |
|
gen_data = next (self.generator_func) |
|
except StopIteration: |
|
self.cs_queue.put (None) |
|
return |
|
self.cs_queue.put (gen_data) |
|
self.prefetch -= 1 |
|
self.sc_queue.get() |
|
self.prefetch += 1 |
|
|
|
def __iter__(self): |
|
return self |
|
|
|
def __getstate__(self): |
|
self_dict = self.__dict__.copy() |
|
del self_dict['p'] |
|
return self_dict |
|
|
|
def __next__(self): |
|
self._start() |
|
gen_data = self.cs_queue.get() |
|
if gen_data is None: |
|
self.p.terminate() |
|
self.p.join() |
|
raise StopIteration() |
|
self.sc_queue.put (1) |
|
return gen_data |
|
|