Tweak incoming/outgoing workers
This commit is contained in:
		
							parent
							
								
									0b6556e54a
								
							
						
					
					
						commit
						0696268d0b
					
				
					 5 changed files with 30 additions and 46 deletions
				
			
		|  | @ -69,13 +69,11 @@ def _set_next_try( | ||||||
| 
 | 
 | ||||||
| async def fetch_next_incoming_activity( | async def fetch_next_incoming_activity( | ||||||
|     db_session: AsyncSession, |     db_session: AsyncSession, | ||||||
|     in_flight: set[int], |  | ||||||
| ) -> models.IncomingActivity | None: | ) -> models.IncomingActivity | None: | ||||||
|     where = [ |     where = [ | ||||||
|         models.IncomingActivity.next_try <= now(), |         models.IncomingActivity.next_try <= now(), | ||||||
|         models.IncomingActivity.is_errored.is_(False), |         models.IncomingActivity.is_errored.is_(False), | ||||||
|         models.IncomingActivity.is_processed.is_(False), |         models.IncomingActivity.is_processed.is_(False), | ||||||
|         models.IncomingActivity.id.not_in(in_flight), |  | ||||||
|     ] |     ] | ||||||
|     q_count = await db_session.scalar( |     q_count = await db_session.scalar( | ||||||
|         select(func.count(models.IncomingActivity.id)).where(*where) |         select(func.count(models.IncomingActivity.id)).where(*where) | ||||||
|  | @ -144,11 +142,11 @@ class IncomingActivityWorker(Worker[models.IncomingActivity]): | ||||||
|         self, |         self, | ||||||
|         db_session: AsyncSession, |         db_session: AsyncSession, | ||||||
|     ) -> models.IncomingActivity | None: |     ) -> models.IncomingActivity | None: | ||||||
|         return await fetch_next_incoming_activity(db_session, self.in_flight_ids()) |         return await fetch_next_incoming_activity(db_session) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| async def loop() -> None: | async def loop() -> None: | ||||||
|     await IncomingActivityWorker(workers_count=1).run_forever() |     await IncomingActivityWorker().run_forever() | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| if __name__ == "__main__": | if __name__ == "__main__": | ||||||
|  |  | ||||||
|  | @ -170,13 +170,11 @@ def _set_next_try( | ||||||
| 
 | 
 | ||||||
| async def fetch_next_outgoing_activity( | async def fetch_next_outgoing_activity( | ||||||
|     db_session: AsyncSession, |     db_session: AsyncSession, | ||||||
|     in_fligh: set[int], |  | ||||||
| ) -> models.OutgoingActivity | None: | ) -> models.OutgoingActivity | None: | ||||||
|     where = [ |     where = [ | ||||||
|         models.OutgoingActivity.next_try <= now(), |         models.OutgoingActivity.next_try <= now(), | ||||||
|         models.OutgoingActivity.is_errored.is_(False), |         models.OutgoingActivity.is_errored.is_(False), | ||||||
|         models.OutgoingActivity.is_sent.is_(False), |         models.OutgoingActivity.is_sent.is_(False), | ||||||
|         models.OutgoingActivity.id.not_in(in_fligh), |  | ||||||
|     ] |     ] | ||||||
|     q_count = await db_session.scalar( |     q_count = await db_session.scalar( | ||||||
|         select(func.count(models.OutgoingActivity.id)).where(*where) |         select(func.count(models.OutgoingActivity.id)).where(*where) | ||||||
|  | @ -289,14 +287,14 @@ class OutgoingActivityWorker(Worker[models.OutgoingActivity]): | ||||||
|         self, |         self, | ||||||
|         db_session: AsyncSession, |         db_session: AsyncSession, | ||||||
|     ) -> models.OutgoingActivity | None: |     ) -> models.OutgoingActivity | None: | ||||||
|         return await fetch_next_outgoing_activity(db_session, self.in_flight_ids()) |         return await fetch_next_outgoing_activity(db_session) | ||||||
| 
 | 
 | ||||||
|     async def startup(self, db_session: AsyncSession) -> None: |     async def startup(self, db_session: AsyncSession) -> None: | ||||||
|         await _send_actor_update_if_needed(db_session) |         await _send_actor_update_if_needed(db_session) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| async def loop() -> None: | async def loop() -> None: | ||||||
|     await OutgoingActivityWorker(workers_count=3).run_forever() |     await OutgoingActivityWorker().run_forever() | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| if __name__ == "__main__": | if __name__ == "__main__": | ||||||
|  |  | ||||||
|  | @ -12,30 +12,9 @@ T = TypeVar("T") | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class Worker(Generic[T]): | class Worker(Generic[T]): | ||||||
|     def __init__(self, workers_count: int) -> None: |     def __init__(self) -> None: | ||||||
|         self._loop = asyncio.get_event_loop() |         self._loop = asyncio.get_event_loop() | ||||||
|         self._in_flight: set[int] = set() |  | ||||||
|         self._queue: asyncio.Queue[T] = asyncio.Queue(maxsize=1) |  | ||||||
|         self._stop_event = asyncio.Event() |         self._stop_event = asyncio.Event() | ||||||
|         self._workers_count = workers_count |  | ||||||
| 
 |  | ||||||
|     async def _consumer(self, db_session: AsyncSession) -> None: |  | ||||||
|         while not self._stop_event.is_set(): |  | ||||||
|             message = await self._queue.get() |  | ||||||
|             try: |  | ||||||
|                 await self.process_message(db_session, message) |  | ||||||
|             finally: |  | ||||||
|                 self._in_flight.remove(message.id)  # type: ignore |  | ||||||
|                 self._queue.task_done() |  | ||||||
| 
 |  | ||||||
|     async def _producer(self, db_session: AsyncSession) -> None: |  | ||||||
|         while not self._stop_event.is_set(): |  | ||||||
|             next_message = await self.get_next_message(db_session) |  | ||||||
|             if next_message: |  | ||||||
|                 self._in_flight.add(next_message.id)  # type: ignore |  | ||||||
|                 await self._queue.put(next_message) |  | ||||||
|             else: |  | ||||||
|                 await asyncio.sleep(1) |  | ||||||
| 
 | 
 | ||||||
|     async def process_message(self, db_session: AsyncSession, message: T) -> None: |     async def process_message(self, db_session: AsyncSession, message: T) -> None: | ||||||
|         raise NotImplementedError |         raise NotImplementedError | ||||||
|  | @ -46,8 +25,16 @@ class Worker(Generic[T]): | ||||||
|     async def startup(self, db_session: AsyncSession) -> None: |     async def startup(self, db_session: AsyncSession) -> None: | ||||||
|         return None |         return None | ||||||
| 
 | 
 | ||||||
|     def in_flight_ids(self) -> set[int]: |     async def _main_loop(self, db_session: AsyncSession) -> None: | ||||||
|         return self._in_flight |         while not self._stop_event.is_set(): | ||||||
|  |             next_message = await self.get_next_message(db_session) | ||||||
|  |             if next_message: | ||||||
|  |                 await self.process_message(db_session, next_message) | ||||||
|  |             else: | ||||||
|  |                 await asyncio.sleep(1) | ||||||
|  | 
 | ||||||
|  |     async def _until_stopped(self) -> None: | ||||||
|  |         await self._stop_event.wait() | ||||||
| 
 | 
 | ||||||
|     async def run_forever(self) -> None: |     async def run_forever(self) -> None: | ||||||
|         signals = (signal.SIGHUP, signal.SIGTERM, signal.SIGINT) |         signals = (signal.SIGHUP, signal.SIGTERM, signal.SIGINT) | ||||||
|  | @ -59,13 +46,14 @@ class Worker(Generic[T]): | ||||||
| 
 | 
 | ||||||
|         async with async_session() as db_session: |         async with async_session() as db_session: | ||||||
|             await self.startup(db_session) |             await self.startup(db_session) | ||||||
|             self._loop.create_task(self._producer(db_session)) |             task = self._loop.create_task(self._main_loop(db_session)) | ||||||
|             for _ in range(self._workers_count): |             stop_task = self._loop.create_task(self._until_stopped()) | ||||||
|                 self._loop.create_task(self._consumer(db_session)) |  | ||||||
| 
 | 
 | ||||||
|             await self._stop_event.wait() |             done, pending = await asyncio.wait( | ||||||
|             logger.info("Waiting for tasks to finish") |                 {task, stop_task}, return_when=asyncio.FIRST_COMPLETED | ||||||
|             await self._queue.join() |             ) | ||||||
|  |             logger.info(f"Waiting for tasks to finish {done=}/{pending=}") | ||||||
|  |             await asyncio.sleep(5) | ||||||
|             tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()] |             tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()] | ||||||
|             logger.info(f"Cancelling {len(tasks)} tasks") |             logger.info(f"Cancelling {len(tasks)} tasks") | ||||||
|             [task.cancel() for task in tasks] |             [task.cancel() for task in tasks] | ||||||
|  |  | ||||||
|  | @ -24,7 +24,7 @@ from tests.utils import setup_remote_actor_as_follower | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| async def _process_next_incoming_activity(db_session: AsyncSession) -> None: | async def _process_next_incoming_activity(db_session: AsyncSession) -> None: | ||||||
|     next_activity = await fetch_next_incoming_activity(db_session, set()) |     next_activity = await fetch_next_incoming_activity(db_session) | ||||||
|     assert next_activity |     assert next_activity | ||||||
|     await process_next_incoming_activity(db_session, next_activity) |     await process_next_incoming_activity(db_session, next_activity) | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -70,7 +70,7 @@ async def test_process_next_outgoing_activity__no_next_activity( | ||||||
|     respx_mock: respx.MockRouter, |     respx_mock: respx.MockRouter, | ||||||
|     async_db_session: AsyncSession, |     async_db_session: AsyncSession, | ||||||
| ) -> None: | ) -> None: | ||||||
|     next_activity = await fetch_next_outgoing_activity(async_db_session, set()) |     next_activity = await fetch_next_outgoing_activity(async_db_session) | ||||||
|     assert next_activity is None |     assert next_activity is None | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | @ -94,7 +94,7 @@ async def test_process_next_outgoing_activity__server_200( | ||||||
| 
 | 
 | ||||||
|     # When processing the next outgoing activity |     # When processing the next outgoing activity | ||||||
|     # Then it is processed |     # Then it is processed | ||||||
|     next_activity = await fetch_next_outgoing_activity(async_db_session, set()) |     next_activity = await fetch_next_outgoing_activity(async_db_session) | ||||||
|     assert next_activity |     assert next_activity | ||||||
|     await process_next_outgoing_activity(async_db_session, next_activity) |     await process_next_outgoing_activity(async_db_session, next_activity) | ||||||
| 
 | 
 | ||||||
|  | @ -129,7 +129,7 @@ async def test_process_next_outgoing_activity__webmention( | ||||||
| 
 | 
 | ||||||
|     # When processing the next outgoing activity |     # When processing the next outgoing activity | ||||||
|     # Then it is processed |     # Then it is processed | ||||||
|     next_activity = await fetch_next_outgoing_activity(async_db_session, set()) |     next_activity = await fetch_next_outgoing_activity(async_db_session) | ||||||
|     assert next_activity |     assert next_activity | ||||||
|     await process_next_outgoing_activity(async_db_session, next_activity) |     await process_next_outgoing_activity(async_db_session, next_activity) | ||||||
| 
 | 
 | ||||||
|  | @ -165,7 +165,7 @@ async def test_process_next_outgoing_activity__error_500( | ||||||
| 
 | 
 | ||||||
|     # When processing the next outgoing activity |     # When processing the next outgoing activity | ||||||
|     # Then it is processed |     # Then it is processed | ||||||
|     next_activity = await fetch_next_outgoing_activity(async_db_session, set()) |     next_activity = await fetch_next_outgoing_activity(async_db_session) | ||||||
|     assert next_activity |     assert next_activity | ||||||
|     await process_next_outgoing_activity(async_db_session, next_activity) |     await process_next_outgoing_activity(async_db_session, next_activity) | ||||||
| 
 | 
 | ||||||
|  | @ -203,7 +203,7 @@ async def test_process_next_outgoing_activity__errored( | ||||||
| 
 | 
 | ||||||
|     # When processing the next outgoing activity |     # When processing the next outgoing activity | ||||||
|     # Then it is processed |     # Then it is processed | ||||||
|     next_activity = await fetch_next_outgoing_activity(async_db_session, set()) |     next_activity = await fetch_next_outgoing_activity(async_db_session) | ||||||
|     assert next_activity |     assert next_activity | ||||||
|     await process_next_outgoing_activity(async_db_session, next_activity) |     await process_next_outgoing_activity(async_db_session, next_activity) | ||||||
| 
 | 
 | ||||||
|  | @ -218,7 +218,7 @@ async def test_process_next_outgoing_activity__errored( | ||||||
|     assert outgoing_activity.is_errored is True |     assert outgoing_activity.is_errored is True | ||||||
| 
 | 
 | ||||||
|     # And it is skipped from processing |     # And it is skipped from processing | ||||||
|     next_activity = await fetch_next_outgoing_activity(async_db_session, set()) |     next_activity = await fetch_next_outgoing_activity(async_db_session) | ||||||
|     assert next_activity is None |     assert next_activity is None | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | @ -241,7 +241,7 @@ async def test_process_next_outgoing_activity__connect_error( | ||||||
| 
 | 
 | ||||||
|     # When processing the next outgoing activity |     # When processing the next outgoing activity | ||||||
|     # Then it is processed |     # Then it is processed | ||||||
|     next_activity = await fetch_next_outgoing_activity(async_db_session, set()) |     next_activity = await fetch_next_outgoing_activity(async_db_session) | ||||||
|     assert next_activity |     assert next_activity | ||||||
|     await process_next_outgoing_activity(async_db_session, next_activity) |     await process_next_outgoing_activity(async_db_session, next_activity) | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue