# -*- coding: utf-8 -*-
"""A small thread pool modeled on ``multiprocessing.Pool``.

Tasks flow through three queues:

* ``_taskqueue`` receives raw tasks from :meth:`Pool.apply_async`;
* ``_inqueue`` feeds the worker threads;
* ``_outqueue`` carries ``(job, index, (success, value))`` tuples back.

Three daemon handler threads move data between those queues and the
``_cache`` dict that maps job ids to pending :class:`ApplyResult` objects.
"""

import itertools
import os
import queue
import threading
import time
import traceback


# Life-cycle states shared by the pool and its handler threads.
RUN = 0
CLOSE = 1
TERMINATE = 2

# Module-wide source of job ids; every ApplyResult takes the next value.
job_counter = itertools.count()


class Pool(object):
    """Run callables on a fixed number of daemon worker threads."""

    def __init__(self, max_thread_num=None):
        """Start the workers and the three handler threads.

        Args:
            max_thread_num: Number of worker threads. ``None`` means
                ``os.cpu_count()`` (or 1 when that is unavailable).

        Raises:
            ValueError: If ``max_thread_num`` is smaller than 1.
        """
        max_thread_num = self.__chk_thread_num(max_thread_num)
        self._setup_queues()
        self._cache = {}  # job id -> ApplyResult still waiting for a result
        self._state = RUN
        self._max_num = max_thread_num  # upper bound on worker threads
        self._pool = []  # the worker threads themselves
        self._add_thread_to_pool()

        # Watches the pool state / result cache and wakes the task handler
        # once the pool is shutting down.
        self._worker_handler = threading.Thread(
            target=Pool._handle_workers,
            args=(self, ),
        )
        self.__init_handler(self._worker_handler)

        # Moves submitted tasks from _taskqueue into the workers' _inqueue.
        self._task_handler = threading.Thread(
            target=Pool._handle_tasks,
            args=(self._taskqueue, self._quick_put, self._outqueue,
                  self._pool, self._cache),
        )
        self.__init_handler(self._task_handler)

        # Reads _outqueue and stores each result in its ApplyResult.
        self._result_handler = threading.Thread(
            target=Pool._handle_results,
            args=(self._quick_get, self._cache),
        )
        self.__init_handler(self._result_handler)

    def __chk_thread_num(self, max_thread_num):
        """Validate the requested thread count and return the resolved value.

        ``None`` is replaced by ``os.cpu_count() or 1``. The value must be
        returned (not just rebound locally) so the caller actually sees the
        resolved default.

        Raises:
            ValueError: If the count is smaller than 1.
        """
        if max_thread_num is None:
            max_thread_num = os.cpu_count() or 1
        if max_thread_num < 1:
            raise ValueError("Number of threads should be at least 1")
        return max_thread_num

    @staticmethod
    def __init_handler(handler):
        """Mark ``handler`` as a daemon in the RUN state and start it."""
        handler.daemon = True
        handler._state = RUN
        handler.start()

    def _add_thread_to_pool(self):
        """Start daemon workers until the pool holds ``_max_num`` threads."""
        for _ in range(self._max_num - len(self._pool)):
            thread_worker = self.Process(target=self.worker)
            self._pool.append(thread_worker)
            thread_worker.daemon = True
            thread_worker.start()

    def _setup_queues(self):
        """Create the three queues and cache the hot put/get bound methods."""
        self._inqueue = queue.Queue()  # worker input: task tuples
        self._outqueue = queue.Queue()  # worker output: task results
        self._quick_put = self._inqueue.put
        self._quick_get = self._outqueue.get
        self._taskqueue = queue.Queue()  # raw tasks as submitted by callers

    def Process(self, *args, **kwds):
        """Create (but do not start) the thread object used for a worker."""
        return threading.Thread(*args, **kwds)

    def apply(self, func, args=(), kwds=None):
        """Run ``func(*args, **kwds)`` on a worker and block for the result.

        Returns the function's return value, or re-raises the exception it
        raised.
        """
        assert self._state == RUN
        return self.apply_async(func, args, kwds).get()

    def apply_async(self, func, args=(), kwds=None, callback=None):
        """Queue ``func(*args, **kwds)`` and return immediately.

        Args:
            func: Callable to run on a worker thread.
            args: Positional arguments for ``func``.
            kwds: Keyword arguments for ``func`` (``None`` means none).
            callback: Optional one-argument callable invoked with the
                return value when the task succeeds.

        Returns:
            An :class:`ApplyResult` from which the outcome can be fetched.

        Raises:
            ValueError: If the pool is no longer in the RUN state.
        """
        if self._state != RUN:
            raise ValueError("Pool not running")
        if kwds is None:
            kwds = {}
        result = ApplyResult(self._cache, callback)
        self._taskqueue.put((result._job, None, func, args, kwds))
        return result

    @staticmethod
    def _handle_workers(pool):
        """Body of the worker-handler thread.

        Polls every 0.1 s while the handler is in the RUN state, or while
        results are still pending and the pool has not been terminated.
        Afterwards it puts the ``None`` sentinel on the task queue so the
        task handler stops.
        """
        thread = threading.current_thread()
        while thread._state == RUN or (pool._cache and
                                       thread._state != TERMINATE):
            time.sleep(0.1)
        pool._taskqueue.put(None)

    @staticmethod
    def _handle_tasks(taskqueue, put, outqueue, pool, cache):
        """Body of the task-handler thread.

        Forwards tasks from ``taskqueue`` to the workers via ``put`` until
        the ``None`` sentinel arrives or the thread leaves the RUN state,
        then sends sentinels to the result handler and to every worker.
        """
        thread = threading.current_thread()
        try:
            # iter() stops as soon as taskqueue.get returns None.
            for task_info in iter(taskqueue.get, None):
                if thread._state:
                    break
                try:
                    put(task_info)
                except Exception as e:
                    # Report the failure through the job's own result.
                    job = task_info[0]
                    try:
                        cache[job]._set((False, e))
                    except KeyError:
                        pass
        except Exception:
            # An unexpected error cannot be attributed to a specific job, so
            # report it instead of failing an unrelated one; the sentinels
            # below are still sent so the other threads can stop.
            print('task handler stopped unexpectedly:\n%s'
                  % traceback.format_exc())

        try:
            outqueue.put(None)  # tell the result handler to stop
            # One sentinel per worker so each of them stops.
            for _ in pool:
                put(None)
        except OSError:
            print('task handler got OSError when sending sentinels')

    @staticmethod
    def _handle_results(get, cache):
        """Body of the result-handler thread.

        Delivers ``(job, index, outcome)`` tuples obtained from ``get`` to
        the matching :class:`ApplyResult` in ``cache`` until a ``None``
        sentinel arrives or the thread is terminated, then keeps draining
        while results are still pending.
        """
        thread = threading.current_thread()
        while True:
            try:
                task_rtn = get()
            except (OSError, EOFError):
                return

            if thread._state:
                assert thread._state == TERMINATE
                break

            if task_rtn is None:
                break

            job, i, obj = task_rtn
            try:
                cache[job]._set(obj)
            except KeyError:
                pass

        # Deliver whatever is still outstanding unless we were terminated.
        while cache and thread._state != TERMINATE:
            try:
                task_rtn = get()
            except (OSError, EOFError):
                return

            if task_rtn is None:
                continue
            job, i, obj = task_rtn
            try:
                cache[job]._set(obj)
            except KeyError:
                pass

    def close(self):
        """Stop accepting new tasks; already submitted ones still run."""
        if self._state == RUN:
            self._state = CLOSE
            self._worker_handler._state = CLOSE

    def terminate(self):
        """Stop the pool without running tasks that are still queued."""
        self._state = TERMINATE
        self._worker_handler._state = TERMINATE
        self._terminate_pool(self._inqueue, self._outqueue, self._pool,
                             self._worker_handler, self._task_handler,
                             self._result_handler, self._cache)

    def join(self):
        """Wait for the handler threads and all workers to finish.

        ``close()`` or ``terminate()`` must have been called first.
        """
        assert self._state in (CLOSE, TERMINATE)
        self._worker_handler.join()
        self._task_handler.join()
        self._result_handler.join()
        for p in self._pool:
            p.join()

    @staticmethod
    def _help_stuff_finish(inqueue, size):
        """Replace the contents of ``inqueue`` with ``size`` sentinels.

        Pending tasks are discarded and blocked workers are woken so each
        of them receives a ``None`` and stops.
        """
        with inqueue.not_empty:
            inqueue.queue.clear()
            inqueue.queue.extend([None] * size)
            inqueue.not_empty.notify_all()

    @classmethod
    def _terminate_pool(cls, inqueue, outqueue, pool,
                        worker_handler, task_handler, result_handler, cache):
        """Flag every handler as TERMINATE, flush the workers' input queue
        and wait for the handler threads to stop."""
        worker_handler._state = TERMINATE
        task_handler._state = TERMINATE

        cls._help_stuff_finish(inqueue, len(pool))

        assert result_handler.is_alive() or len(cache) == 0

        result_handler._state = TERMINATE
        outqueue.put(None)  # wakes the result handler so it sees the state

        # Wait for the three handlers; a handler must not join itself.
        for handler in (worker_handler, task_handler, result_handler):
            if threading.current_thread() is not handler:
                handler.join()

    def worker(self):
        """Body of a worker thread.

        Takes tasks from ``_inqueue`` until a ``None`` sentinel arrives,
        runs them, and puts ``(job, index, (success, value))`` on
        ``_outqueue``. An exception raised by the task is printed with its
        traceback and delivered to the caller as a failed result.
        """
        while True:
            try:
                task_info = self._inqueue.get()
            except (EOFError, OSError):
                break

            if task_info is None:
                break

            job, i, func, args, kwds = task_info
            try:
                result = (True, func(*args, **kwds))
            except Exception as e:
                # format_exc() gives the readable traceback text;
                # e.__traceback__ would only print an object repr.
                print('Exception occurred: %s\n%s'
                      % (e, traceback.format_exc()))
                result = (False, e)
            try:
                self._outqueue.put((job, i, result))
            except Exception as e:
                err_msg = "Exception occurred while sending %s: %s" % (
                    result[1], e)
                print(err_msg)
                self._outqueue.put((job, i, (False, err_msg)))


class ApplyResult(object):
    """Handle for the outcome of one task submitted to a :class:`Pool`."""

    def __init__(self, cache, callback):
        """Register a new pending result in ``cache`` under a fresh job id.

        Args:
            cache: Dict shared with the pool, mapping job id to result.
            callback: Optional callable invoked with the value on success.
        """
        self._event = threading.Event()
        self._job = next(job_counter)
        self._cache = cache
        self._callback = callback
        self._success = False
        self._value = None
        cache[self._job] = self

    def ready(self):
        """Return True once the task's outcome has been stored."""
        return self._event.is_set()

    def wait(self, timeout=None):
        """Block until the outcome is stored or ``timeout`` seconds pass."""
        self._event.wait(timeout)

    def get(self, timeout=None):
        """Return the task's value, re-raising the task's exception if any.

        Args:
            timeout: Maximum number of seconds to wait; ``None`` waits
                indefinitely.

        Raises:
            TimeoutError: If no outcome arrived within ``timeout``.
        """
        self.wait(timeout)
        if not self.ready():
            # Without this, the code would fall through to
            # ``raise self._value`` with ``_value`` still None.
            raise TimeoutError('Timeout!')
        if self._success:
            return self._value
        raise self._value

    def _set(self, obj):
        """Store the ``(success, value)`` outcome and release waiters.

        The callback runs first (only on success). An exception raised by
        the callback is reported but not propagated, so the outcome is
        still published and the job is still removed from the cache.
        """
        self._success, self._value = obj
        if self._callback and self._success:
            try:
                self._callback(self._value)
            except Exception:
                print('Exception occurred in callback:\n%s'
                      % traceback.format_exc())
        self._event.set()
        del self._cache[self._job]
参考 multiprocessing 的 Pool 实现的简化版(功能有所削减)线程池:把其中的进程替换成了线程。
时间: 2024-10-13 16:21:44