Source code for sqlalchemy_aio.asyncio

import asyncio
import threading
import warnings
from concurrent.futures import CancelledError
from functools import partial

import outcome

from .base import AsyncEngine, ThreadWorker
from .exc import AlreadyQuit, SQLAlchemyAioDeprecationWarning


class Request:
    """A unit of work handed to the worker thread, plus its completion flag.

    Attributes are plain and public:

    * ``func`` -- the callable to execute (``None`` is used by
      ``AsyncioThreadWorker.quit`` as a stop marker).
    * ``finished`` -- an :class:`asyncio.Event` set once the request has
      been handled.
    * ``response`` -- whatever the worker stores as the result of ``func``;
      ``None`` until then.
    """

    def __init__(self, func):
        # Nothing has been produced yet; the worker thread overwrites this.
        self.response = None
        self.finished = asyncio.Event()
        self.func = func

    def set_finished(self):
        """Mark the request as handled.

        This has to run on the event loop's own thread, because
        :class:`asyncio.Event` is not thread-safe.
        """
        self.finished.set()


class AsyncioThreadWorker(ThreadWorker):
    """Execute callables on a dedicated thread on behalf of asyncio code.

    Requests travel from the event loop to the thread through a one-slot
    :class:`asyncio.Queue`.  The thread signals completion by scheduling
    ``Request.set_finished`` on the loop with ``call_soon_threadsafe``.

    A worker created with ``branch_from`` shares the queue and thread of the
    worker it branches from rather than starting its own.
    """

    def __init__(self, *, branch_from=None):
        """Bind to the current event loop and set up the queue and thread.

        :param branch_from: an existing worker whose request queue and
            thread should be reused; ``None`` starts a new daemon thread.
        """
        self._loop = asyncio.get_event_loop()
        self._branched = branch_from is not None
        self._has_quit = False

        if self._branched:
            # Branched workers piggyback on the parent's queue and thread.
            self._request_queue = branch_from._request_queue
            self._thread = branch_from._thread
            return

        self._request_queue = asyncio.Queue(maxsize=1)
        self._thread = threading.Thread(daemon=True, target=self.thread_fn)
        self._thread.start()

    def thread_fn(self):
        """Body of the worker thread: serve requests until a stop marker.

        Each iteration pulls one request off the queue (via the event loop),
        runs its ``func`` if there is one, and then notifies the loop that
        the request is finished.  A request whose ``func`` is ``None`` ends
        the loop after being acknowledged.
        """
        while True:
            # The queue lives on the event loop, so the blocking ``get`` has
            # to be submitted there and waited on from this thread.
            pending = asyncio.run_coroutine_threadsafe(
                self._request_queue.get(), self._loop)
            try:
                request = pending.result()
            except CancelledError:
                # The fetch was cancelled; just try again.
                continue

            is_stop_marker = request.func is None
            if not is_stop_marker:
                # ``outcome.capture`` records either the return value or the
                # raised exception so ``run`` can replay it with ``unwrap``.
                request.response = outcome.capture(request.func)

            # ``asyncio.Event`` is not thread-safe: set it from the loop.
            self._loop.call_soon_threadsafe(request.set_finished)

            if is_stop_marker:
                return

    async def run(self, func, args=(), kwargs=None):
        """Call ``func(*args, **kwargs)`` on the worker thread and await it.

        :param func: the callable to execute.
        :param args: positional arguments for ``func``.
        :param kwargs: keyword arguments for ``func`` (``None`` for none).
        :returns: the result of unwrapping the captured outcome of ``func``.
        :raises AlreadyQuit: if ``quit`` has already been called.
        """
        if self._has_quit:
            raise AlreadyQuit

        # ``kwargs`` is tested first, mirroring the original branch order.
        if kwargs or args:
            func = partial(func, *args, **(kwargs or {}))

        request = Request(func)
        await self._request_queue.put(request)
        await request.finished.wait()

        captured = request.response
        return captured.unwrap()

    async def quit(self):
        """Mark this worker as quit and, unless branched, stop the thread.

        A non-branched worker enqueues a stop marker (a ``Request`` with no
        callable) and waits until the thread acknowledges it.  A branched
        worker only flips its own flag, leaving the shared thread running.

        :raises AlreadyQuit: if ``quit`` has already been called.
        """
        if self._has_quit:
            raise AlreadyQuit
        self._has_quit = True

        if not self._branched:
            stop_marker = Request(None)
            await self._request_queue.put(stop_marker)
            await stop_marker.finished.wait()


class AsyncioEngine(AsyncEngine):
    """Mostly like :class:`sqlalchemy.engine.Engine` except some of the methods
    are coroutines."""

    def __init__(self, pool, dialect, url, logging_name=None, echo=None,
                 execution_options=None, **kwargs):
        """Forward every argument unchanged to ``AsyncEngine.__init__``."""
        super().__init__(
            pool, dialect, url, logging_name, echo, execution_options,
            **kwargs)

    def _make_worker(self, *, branch_from=None):
        """Return a new :class:`AsyncioThreadWorker`.

        :param branch_from: passed straight through to the worker; when
            given, the new worker shares that worker's queue and thread.
        """
        return AsyncioThreadWorker(branch_from=branch_from)