"""定义asyncio环境下使用的mprpc的客户端.
+ File: aio.py
+ Version: 0.5
+ Author: hsz
+ Email: hsz1273327@gmail.com
+ Copyright: 2018-02-08 hsz
+ License: MIT
+ History
+ 2018-01-23 created by hsz
+ 2018-02-08 version-0.5 by hsz
"""
import asyncio
import uuid
import zlib
import bz2
import lzma
import warnings
from typing import (
Optional,
Dict,
Any
)
from urllib.parse import urlparse
from pymprpc.errors import (
MprpcException,
abort
)
from pymprpc.mixins.encoder_decoder_mixin import EncoderDecoderMixin
from .utils import Method
[docs]class AsyncRPC(EncoderDecoderMixin):
"""异步的RPC客户端.
可以通过设置debug=True规定传输使用json且显示中间过程中的信息.
Attributes:
SEPARATOR (bytes): - 协议规定的请求响应终止符
VERSION (str): - 协议版本,以`x.x`的形式表现版本
COMPRESERS (Dict[str,model]): - 支持的压缩解压工具
username (Optional[str]): - 登录远端的用户名
password (Optional[str]): - 登录远端的密码
hostname (str): - 远端的主机地址
port (int): - 远端的主机端口
loop (asyncio.AbstractEventLoop): - 使用的事件循环
debug (bool): - 是否使用debug模式
compreser (Optional[str]): - 是否使用压缩工具压缩传输信息,以及压缩工具是什么
heart_beat (Optional[int]): - 是否使用心跳机制确保连接不会因过期而断开
closed (bool): - 客户端是否已经关闭或者还未开始运转
reader (asyncio.StreamReader): - 流读取对象
writer (asyncio.StreamWriter): - 流写入对象
tasks (Dict[str,asyncio.Future]): - 远端执行的任务,保存以ID为键
gens (Dict[str,Any]): - 远端执行的流返回任务,保存以ID为键
gens_res (Dict[str,List[Any]]): - 远端执行的流返回任务的结果,保存以ID为键
remote_info (Dict[str,Any]): - 通过验证后返回的远端服务信息
"""
SEPARATOR = b"##PRO-END##"
COMPRESERS = {
"zlib": zlib,
"bz2": bz2,
"lzma": lzma
}
VERSION = "0.1"
def __init__(self,
addr: str,
loop: Optional[asyncio.AbstractEventLoop]=None,
debug: bool=False,
compreser: Optional[str]=None,
heart_beat: Optional[int]=None):
"""初始化RPC客户端.
Parameters:
addr (str): - 形如`tcp://xxx:xxx@xxx:xxx`的字符串
loop (Optional[asyncio.AbstractEventLoop]): - 事件循环
debug (bool): - 是否使用debug模式,默认为否
compreser(Optional[str]): - 是否使用压缩工具压缩传输信息,以及压缩工具是什么,默认为不使用.
heart_beat (Optional[int]):- 是否使用心跳机制确保连接不会因过期而断开,默认为不使用.
"""
pas = urlparse(addr)
if pas.scheme != "tcp":
raise abort(505, "unsupported scheme for this protocol")
# public
self.username = pas.username
self.password = pas.password
self.hostname = pas.hostname
self.port = pas.port
self.loop = loop or asyncio.get_event_loop()
self.debug = debug
if compreser is not None:
_compreser = self.COMPRESERS.get(compreser)
if _compreser is not None:
self.compreser = _compreser
else:
raise RuntimeError("compreser unsupport")
else:
self.compreser = None
self.heart_beat = heart_beat
self.closed = True
self.reader = None
self.writer = None
self.tasks = {}
self.remote_info = None
# protected
self._gens_queue = {}
self._login_fut = None
self._response_task = None
self._heartbeat_task = None
if self.debug is True:
self.loop.set_debug(True)
async def __aenter__(self):
if self.debug is True:
print('entering context')
await self.connect()
return self
async def __aexit__(self, exc_type, exc, tb):
if self.debug is True:
print('exit context')
self.close()
[docs] async def reconnect(self):
"""断线重连."""
self.clean()
try:
self.writer.close()
except:
pass
self.closed = True
await self.connect()
if self.debug:
print("reconnect to {}".format((self.hostname, self.port)))
[docs] async def connect(self):
"""与远端建立连接.
主要执行的操作有:
+ 将监听响应的协程_response_handler放入事件循环
+ 如果有验证信息则发送验证信息
+ 获取连接建立的返回
"""
self.reader, self.writer = await asyncio.open_connection(
host=self.hostname,
port=self.port,
loop=self.loop)
self.closed = False
self._response_task = asyncio.ensure_future(self._response_handler())
query = {
"MPRPC": self.VERSION,
"AUTH": {
"USERNAME": self.username,
"PASSWORD": self.password
}
}
queryb = self.encoder(query)
if self.debug is True:
print("send auth {}".format(queryb))
self.writer.write(queryb)
self._login_fut = self.loop.create_future()
self.remote_info = await self._login_fut
if self.remote_info is False:
raise abort(501)
if self.heart_beat:
self._heartbeat_task = asyncio.ensure_future(
self._heartbeat_callback())
[docs] def clean(self):
"""清理还在执行或者等待执行的协程."""
for _, i in self.tasks.items():
i.cancel()
self._heartbeat_task.cancel()
self._response_task.cancel()
[docs] def close(self):
"""关闭与远端的连接.
判断标志位closed是否为False,如果是则关闭,否则不进行操作
"""
if self.closed is False:
self.clean()
try:
self.writer.write_eof()
except:
pass
self.writer.close()
self.closed = True
if self.debug is True:
print('close')
if self.debug is True:
print("closed")
# ------------------------心跳操作-------------------------------
async def _heartbeat_callback(self):
"""如果设置了心跳,则调用这个协程."""
query = {
"MPRPC": self.VERSION,
"HEARTBEAT": "ping"
}
queryb = self.encoder(query)
while True:
await asyncio.sleep(self.heart_beat)
self.writer.write(queryb)
if self.debug is True:
print("ping")
# ------------------------读取response操作-----------------------
async def _response_handler(self):
"""监听响应数据的协程函数.`connect`被调用后会被创建为一个协程并放入事件循环."""
if self.debug is True:
if self.debug is True:
print("listenning response!")
while True:
try:
res = await self.reader.readuntil(self.SEPARATOR)
except:
raise
else:
response = self.decoder(res)
self._status_code_check(response)
def _status_code_check(self, response: Dict[str, Any]):
"""检查响应码并进行对不同的响应进行处理.
主要包括:
+ 编码在500~599段为服务异常,直接抛出对应异常
+ 编码在400~499段为调用异常,为对应ID的future设置异常
+ 编码在300~399段为警告,会抛出对应警告
+ 编码在200~399段为执行成功响应,将结果设置给对应ID的future.
+ 编码在100~199段为服务器响应,主要是处理验证响应和心跳响应
Parameters:
response (Dict[str, Any]): - 响应的python字典形式数据
Return:
(bool): - 如果是非服务异常类的响应,那么返回True
"""
code = response.get("CODE")
if self.debug:
print("resv:{}".format(response))
print(code)
if code >= 500:
if self.debug:
print("server error")
self._server_error_handler(code)
elif 500 > code >= 400:
if self.debug:
print("call method error")
self._method_error_handler(response)
elif 400 > code >= 200:
if code >= 300:
self._warning_handler(code)
if code in (200, 201, 202, 206, 300, 301):
if self.debug is True:
print("resv resp {}".format(response))
self._method_response_handler(response)
elif 200 > code >= 100:
self._server_response_handler(response)
else:
raise MprpcException("unknow status code {}".format(code))
return True
def _server_error_handler(self, code: int):
"""处理500~599段状态码,抛出对应警告.
Parameters:
(code): - 响应的状态码
Return:
(bool): - 已知的警告类型则返回True,否则返回False
Raise:
(ServerException): - 当返回为服务异常时则抛出对应异常
"""
if code == 501:
self._login_fut.set_result(False)
else:
self.clean()
raise abort(code)
return True
def _method_error_handler(self, response: Dict[str, Any]):
"""处理400~499段状态码,为对应的任务设置异常.
Parameters:s
(response): - 响应的python字典形式数据
Return:
(bool): - 准确地说没有错误就会返回True
"""
exp = response.get('MESSAGE')
code = response.get("CODE")
ID = exp.get("ID")
e = abort(code, ID=ID, message=exp.get('MESSAGE'))
self.tasks[ID].set_exception(e)
return True
def _warning_handler(self, code: int):
"""处理300~399段状态码,抛出对应警告.
Parameters:
(code): - 响应的状态码
Return:
(bool): - 已知的警告类型则返回True,否则返回False
"""
if code == 300:
warnings.warn(
"ExpireWarning",
RuntimeWarning,
stacklevel=3
)
elif code == 301:
warnings.warn(
"ExpireStreamWarning",
RuntimeWarning,
stacklevel=3
)
else:
if self.debug:
print("unknow code {}".format(code))
return False
return True
def _method_response_handler(self, response: Dict[str, Any]):
"""处理200~399段状态码,为对应的响应设置结果.
Parameters:
(response): - 响应的python字典形式数据
Return:
(bool): - 准确地说没有错误就会返回True
"""
code = response.get("CODE")
if code in (200, 300):
self._result_handler(response)
else:
asyncio.ensure_future(self._gen_result_handler(response))
# self.gen_result_handler(response)
def _server_response_handler(self, response: Dict[str, Any]):
"""处理100~199段状态码,针对不同的服务响应进行操作.
Parameters:
(response): - 响应的python字典形式数据
Return:
(bool): - 准确地说没有错误就会返回True
"""
code = response.get("CODE")
if code == 100:
if self.debug:
print("auth succeed")
self._login_fut.set_result(response)
if code == 101:
if self.debug:
print('pong')
return True
# ---------------------应答结果响应处理-------------------
def _result_handler(self, response: Dict[str, Any]):
"""应答结果响应处理.
将结果解析出来设置给任务对应的Future对象上
Parameters:
(response): - 响应的python字典形式数据
Return:
(bool): - 准确地说没有错误就会返回True
"""
res = response.get("MESSAGE")
ID = res.get("ID")
result = res.get("RESULT")
fut = self.tasks.get(ID)
fut.set_result(result)
return True
# -----------------------流式结果响应处理-------------------
async def _gen_result_handler(self, response: Dict[str, Any]):
"""流式结果响应处理.
+ 收到状态码标识201或301的响应后,将tasks中ID对应的Future对象的结果设置为一个用于包装的异步生成器.
并为这个ID创建一个异步队列保存在`_gens_queue[ID]`中用于存取结果
+ 收到状态码标识为202的响应后向对应ID的存取队列中存入一条结果.
+ 收到终止状态码206后向对应ID的异步生成器结果获取队列中存入一个`StopAsyncIteration`对象用于终止异步迭代器
Parameters:
(response): - 响应的python字典形式数据
Return:
(bool): - 准确地说没有错误就会返回True
"""
code = response.get("CODE")
res = response.get("MESSAGE")
ID = res.get("ID")
if code in (201, 301):
ait = self._wrap_gen(ID)
self.tasks.get(ID).set_result(ait)
self._gens_queue[ID] = asyncio.Queue()
if code == 202:
result = res.get('RESULT')
await self._gens_queue[ID].put(result)
if code == 206:
await self._gens_queue[ID].put(StopAsyncIteration())
return True
async def _wrap_gen(self, ID: str):
"""异步迭代器包装.
Parameters:
ID (str): - 任务ID
Yield:
(Any): - 从异步迭代器结果队列中获取的结果
Raise:
(StopAsyncIteration): - 异步迭代器终止时抛出该异常
"""
while True:
result = await self._gens_queue[ID].get()
if isinstance(result, StopAsyncIteration):
del self._gens_queue[ID]
break
else:
yield result
# --------------------------发送请求--------------------------------------
def _make_query(self, ID: str, methodname: str, returnable: bool, *args: Any, **kwargs: Any):
"""将调用请求的ID,方法名,参数包装为请求数据.
Parameters:
ID (str): - 任务ID
methodname (str): - 要调用的方法名
returnable (bool): - 是否要求返回结果
args (Any): - 要调用的方法的位置参数
kwargs (Any): - 要调用的方法的关键字参数
Return:
(Dict[str, Any]) : - 请求的python字典形式
"""
query = {
"MPRPC": self.VERSION,
"ID": ID,
"METHOD": methodname,
"RETURN": returnable,
"ARGS": args,
"KWARGS": kwargs
}
print(query)
return query
def _send_query(self, query: Dict[str, Any]):
"""将请求编码为字节串发送出去给服务端.
Parameters:
(query): - 请求的的python字典形式数据
Return:
(bool): - 准确地说没有错误就会返回True
"""
queryb = self.encoder(query)
self.writer.write(queryb)
if self.debug is True:
print("send query {}".format(queryb))
return True
[docs] def send_query(self, ID, methodname, returnable, *args, **kwargs):
"""将调用请求的ID,方法名,参数包装为请求数据后编码为字节串发送出去.
Parameters:
ID (str): - 任务ID
methodname (str): - 要调用的方法名
returnable (bool): - 是否要求返回结果
args (Any): - 要调用的方法的位置参数
kwargs (Any): - 要调用的方法的关键字参数
Return:
(bool): - 准确地说没有错误就会返回True
"""
query = self._make_query(ID, methodname, returnable, *args, **kwargs)
self._send_query(query)
self.tasks[ID] = self.loop.create_future()
return True
[docs] def delay(self, methodname, *args, **kwargs):
"""调用但不要求返回结果,而是通过系统方法getresult来获取.
Parameters:
methodname (str): - 要调用的方法名
args (Any): - 要调用的方法的位置参数
kwargs (Any): - 要调用的方法的关键字参数
"""
ID = str(uuid.uuid4())
self.send_query(ID, methodname, False, *args, **kwargs)
return ID
def _async_query(self, ID, methodname, *args, **kwargs):
"""将调用请求的ID,方法名,参数包装为请求数据后编码为字节串发送出去.并创建一个Future对象占位.
Parameters:
ID (str): - 任务ID
methodname (str): - 要调用的方法名
args (Any): - 要调用的方法的位置参数
kwargs (Any): - 要调用的方法的关键字参数
Return:
(asyncio.Future): - 返回对应ID的Future对象
"""
self.send_query(ID, methodname, True, *args, **kwargs)
task = self.tasks[ID]
return task
[docs] def async_query(self, methodname, *args, **kwargs):
"""异步调用一个远端的函数.
为调用创建一个ID,并将调用请求的方法名,参数包装为请求数据后编码为字节串发送出去.并创建一个Future对象占位.
Parameters:
methodname (str): - 要调用的方法名
args (Any): - 要调用的方法的位置参数
kwargs (Any): - 要调用的方法的关键字参数
Return:
(asyncio.Future): - 返回对应ID的Future对象
"""
ID = str(uuid.uuid4())
task = self._async_query(ID=ID, methodname=methodname, *args, **kwargs)
return task
def __getattr__(self, name):
"""运算符`.`重载,让远程函数调用可以使用`.`符号设置要调用的函数."""
ID = str(uuid.uuid4())
print(name)
return Method(self._async_query, name, ID)