Source code for pymprpc.server.server

"""Define the base class for mprpc servers.

+ File: core.py
+ Version: 0.5
+ Author: hsz
+ Email: hsz1273327@gmail.com
+ Copyright: 2018-02-08 hsz
+ License: MIT
+ History

    + 2018-01-23 created by hsz
    + 2018-01-23 version-0.5 by hsz
"""
import os
import pydoc
import platform
import inspect
from functools import partial
from concurrent import futures
from ssl import SSLContext
from signal import (
    SIGTERM, SIGINT
)
from typing import (
    Tuple,
    List,
    Optional,
    Any,
    Callable
)

from pymprpc.errors import (
    RpcException,
    NotFindError,
    ParamError,
    RPCRuntimeError
)

from .protocol import MPProtocolServer

from .utils import (
    list_public_methods,
    resolve_dotted_attribute
)

from .log import (
    logger
)

# Pick the asyncio implementation: on Windows prefer the optional
# `aio_windows_patch` drop-in and fall back to the stdlib module with a
# warning when it is not installed.
if platform.system() == "Windows":
    try:
        import aio_windows_patch as asyncio
    except ImportError:
        # Only a missing package triggers the fallback; any other failure
        # (or a KeyboardInterrupt) is allowed to surface.
        import warnings
        warnings.warn(
            "you should install aio_windows_patch to support windows",
            RuntimeWarning,
            stacklevel=3)
        import asyncio

else:
    import asyncio

# uvloop is optional: when it can be imported, install its event loop policy.
try:
    import uvloop
except ImportError:
    pass
else:
    asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())


class BaseServer:
    """Base class for mprpc servers: starts the server and manages connections.

    Attributes:
        addr (Tuple[str, int]): - (host, port) the server listens on
        loop (asyncio.AbstractEventLoop): - event loop the server runs on
        backlog (int): - backlog value handed to `loop.create_server`
        graceful_shutdown_timeout (int): - how long `clean` waits for open
            connections before closing the loop

    """

    version = "0.0.1"

    def __init__(self,
                 addr: Tuple[str, int], *,
                 loop: Optional[asyncio.AbstractEventLoop]=None,
                 func_executor: Optional[futures.Executor]=None,
                 auth: List[Tuple[str, str]]=[("admin", "admin")],
                 timeout: float = 180.0,
                 debug: bool=False,
                 compreser: Optional[str]=None,
                 ssl: Optional[SSLContext]=None,
                 backlog: int=100,
                 graceful_shutdown_timeout: int=10):
        """Store the server settings and prepare the protocol factory.

        Parameters:
            addr (Tuple[str, int]): - (host, port) the server will listen on
            loop (Optional[asyncio.AbstractEventLoop]): - event loop to run
            on; when None, `asyncio.get_event_loop()` is used
            func_executor (Optional[futures.Executor]): - executor for
            registered functions. When given it is installed as the loop's
            default executor; when None a `ProcessPoolExecutor` is stored in
            `_func_executor` but the loop's default executor is left alone

            auth (List[Tuple[str, str]]): - credentials forwarded to the
            protocol, default `[("admin", "admin")]`
            timeout (float): - value forwarded to the protocol, default 180.0
            debug (bool): - enable debug mode on the loop and the protocol,
            default False
            compreser: (Optional[str]): - compression setting forwarded to
            the protocol, default None
            ssl (Optional[SSLContext]): - SSL context passed to
            `loop.create_server`, default None
            backlog (int): - backlog passed to `loop.create_server`,
            default 100
            graceful_shutdown_timeout (int): - how long `clean` waits for
            open connections, default 10

        """
        # public
        self.addr = addr
        self.loop = loop or asyncio.get_event_loop()
        self.ssl = ssl
        self.backlog = backlog
        self.graceful_shutdown_timeout = graceful_shutdown_timeout
        self.pid = None
        self.funcs = {}
        self.instance = None

        self.rpc_server = None
        self.running = False
        # protected
        self._func_executor = func_executor or futures.ProcessPoolExecutor()
        # NOTE(review): the protocol receives the raw `loop` argument (which
        # may be None) rather than `self.loop`, and the shared default `auth`
        # list is passed through as-is -- confirm MPProtocolServer copes with
        # a None loop and never mutates `auth`.
        self._protocol = partial(
            MPProtocolServer,
            method_wrapper=self,
            loop=loop,
            auth=auth,
            timeout=timeout,
            debug=debug,
            compreser=compreser)
        # Only install an executor the caller actually supplied. Passing None
        # to set_default_executor is deprecated since Python 3.8 and rejected
        # with TypeError by newer releases, which made the default
        # constructor call fail.
        if func_executor is not None:
            self.loop.set_default_executor(func_executor)
        if debug is True:
            self.loop.set_debug(True)

    def clean(self):
        """服务结束后清理服务器."""
        # 服务结束阶段
        if self.running is False:
            return False
        logger.info("Stopping worker [%s]", self.pid)
        # 关闭server
        self.rpc_server.close()
        self.loop.run_until_complete(self.rpc_server.wait_closed())
        # 关闭连接
        # 完成所有空转连接的关闭工作
        for connection in MPProtocolServer.CONNECTIONS:
            connection.shutdown()

        # 等待由graceful_shutdown_timeout设置的时间
        # 让还在运转的连接关闭,防止连接一直被挂起
        start_shutdown = 0
        while MPProtocolServer.CONNECTIONS and (
                start_shutdown < self.graceful_shutdown_timeout):
            self.loop.run_until_complete(asyncio.sleep(0.1))
            start_shutdown = start_shutdown + 0.1

        # 在等待graceful_shutdown_timeout设置的时间后
        # 强制关闭所有的连接
        # for conn in MPProtocolServer.CONNECTIONS:
        #     conn.close()
        # 收尾阶段关闭所有协程,
        self.loop.run_until_complete(self.loop.shutdown_asyncgens())
        # 关闭loop
        self.loop.close()
        logger.info("Stopped worker [%s]", self.pid)
        return True

    def run_forever(self):
        """执行服务器."""
        server_coroutine = self.loop.create_server(
            self._protocol,
            self.addr[0],
            self.addr[1],
            ssl=self.ssl,
            backlog=self.backlog
        )
        try:
            self.rpc_server = self.loop.run_until_complete(server_coroutine)
        except BaseException:
            logger.exception("Unable to start server")
            return
        _singals = (SIGINT, SIGTERM)
        for _signal in _singals:
            try:
                self.loop.add_signal_handler(_signal, self.loop.stop)
            except NotImplementedError as ni:
                logger.warning('tried to use loop.add_signal_handler '
                               'but it is not implemented on this platform.')
        self.pid = os.getpid()
        logger.info(
            "Server @{host}:{port}".format(
                host=self.addr[0],
                port=self.addr[1]
            )
        )
        logger.info(
            """Starting worker [{pid}]
  _ __ ___  _ __  _ __ _ __   ___ 
 | '_ ` _ \| '_ \| '__| '_ \ / __|
 | | | | | | |_) | |  | |_) | (__ 
 |_| |_| |_| .__/|_|  | .__/ \___|
           |_|        |_|                 
""".format(pid=self.pid))
        self.running = True
        self.loop.run_forever()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.clean()

    # --------------------introspection--------------------------------
    def register_introspection_functions(self)->None:
        """注册自省函数到函数字典."""
        self.funcs.update({
            'system.listMethods': self.system_listMethods,
            'system.methodSignature': self.system_methodSignature,
            'system.methodHelp': self.system_methodHelp,
            'system.lenConnections': self.system_lenConnections,
            'system.lenUndoneTasks': self.system_lenUndoneTasks
        })

    def system_lenConnections(self)->int:
        """Return the number of entries in `MPProtocolServer.CONNECTIONS`."""
        connections = MPProtocolServer.CONNECTIONS
        return len(connections)

    def system_lenUndoneTasks(self)->int:
        """Return how many entries of `MPProtocolServer.TASKS` are not done yet."""
        return sum(1 for task in MPProtocolServer.TASKS if not task.done())

    def system_listMethods(self)->List[str]:
        """返回所有注册的函数的名字.

        system.listMethods() => ['add', 'subtract', 'multiple']

        Return:
            (list): - 被注册的可调用函数

        """
        methods = set(self.funcs.keys())
        if self.instance is not None:
            if hasattr(self.instance, '_listMethods'):
                methods |= set(self.instance._listMethods())
            elif not hasattr(self.instance, '_dispatch'):
                methods |= set(list_public_methods(self.instance))
        return sorted(methods)

    def system_methodSignature(self, method_name: str)->str:
        """获取函数的签名.

        system.methodSignature('add') => [double, int, int]

        Parameters:
            method_name (str): - 要查看的函数名

        Returns:
            (str): - 签名文本

        """
        method = None
        if method_name in self.funcs:
            method = self.funcs[method_name]
        elif self.instance is not None:
            try:
                method = resolve_dotted_attribute(
                    self.instance,
                    method_name,
                    self.allow_dotted_names
                )
            except AttributeError:
                pass
        if method is None:
            return ""
        else:
            return str(inspect.signature(method))

    def system_methodHelp(self, method_name: str)->str:
        """将docstring返回.

        system.methodHelp('add') => "Adds two integers together"

        Return:
            (str): - 函数的帮助文本

        """
        method = None
        if method_name in self.funcs:
            method = self.funcs[method_name]
        elif self.instance is not None:
            try:
                method = resolve_dotted_attribute(
                    self.instance,
                    method_name,
                    self.allow_dotted_names
                )
            except AttributeError:
                pass
        if method is None:
            return ""
        else:
            return pydoc.getdoc(method)

    # ---------------------function registration------------------------

    def register_instance(self, instance: Any, allow_dotted_names: bool=False):
        """注册一个实例用于执行,注意只能注册一个.

        Parameters:
            instance (Any): - 将一个类的实例注册到rpc
            allow_dotted_names (bool): 是否允许带`.`号的名字

        """
        if self.instance:
            raise RuntimeError("can only register one instance")
        self.instance = instance
        self.allow_dotted_names = allow_dotted_names
        return True

    def register_function(self, name: Optional[str]=None):
        """注册函数.

        Parameters:
            name (Optional[str]): - 将函数注册到的名字,如果为None,name就用其原来的名字

        """
        def wrap(function: Callable)->Any:
            nonlocal name
            if name is None:
                name = function.__name__
            self.funcs[name] = function
            return function
        return wrap

    def set_executor(self, executor: futures.Executor):
        """设置计算密集型任务的执行器.

        Parameters:
            executor (futures.Executor): - 函数调用的执行器

        """
        self.loop.set_default_executor(executor)
        self._func_executor = executor
        return True

    async def apply(self, ID: str, method: str, *args: Any, **kwargs: Any):
        """执行注册的函数或者实例的方法.

        如果函数或者方法是协程则执行协程,如果是函数则使用执行器执行,默认使用的是多进程.

        Parameters:
            ID (str): 任务的ID
            method (str): 任务调用的函数名
            args (Any): 位置参数
            kwargs (Any): 关键字参数

        Raise:
            (RPCRuntimeError): - 当执行调用后抛出了异常,那就算做RPC运行时异常

        Return:
            (Any): - 被调用函数的返回

        """
        func = None
        try:
            # check to see if a matching function has been registered
            func = self.funcs[method]
        except KeyError:
            if self.instance is not None:
                # check for a _dispatch method
                try:
                    func = resolve_dotted_attribute(
                        self.instance,
                        method,
                        self.allow_dotted_names)
                except AttributeError:
                    pass
        if func is not None:
            sig = inspect.signature(func)
            try:
                sig.bind(*args, **kwargs)
            except:
                raise ParamError(
                    "args can not bind to method {}".format(method), ID)
            if method.startswith("system."):
                try:
                    result = func(*args, **kwargs)
                except Exception as e:
                    raise RPCRuntimeError(
                        'Error:{} happend in method {}'.format(
                            e.__class__.__name__,
                            method
                        ),
                        ID
                    )
                else:
                    return result
            if inspect.iscoroutinefunction(func):
                try:
                    result = await func(*args, **kwargs)
                except Exception as e:
                    raise RPCRuntimeError(
                        'Error:{} happend in method {}'.format(
                            e.__class__.__name__,
                            method
                        ),
                        ID
                    )
                else:
                    return result
            elif inspect.isasyncgenfunction(func):
                result = func(*args, **kwargs)
                return result
            elif inspect.isfunction(func) or inspect.ismethod(func):
                try:
                    f = partial(func, *args, **kwargs)
                    result = await self.loop.run_in_executor(None, f)
                except Exception as e:
                    raise RPCRuntimeError(
                        'Error:{} happend in method {}'.format(
                            e.__class__.__name__,
                            method
                        ),
                        ID
                    )
                else:
                    return result
            else:
                raise RpcException('method "%s" is not supported' % method)
        else:
            raise NotFindError('method "%s" is not supported' % method, ID)