Skip to content
Snippets Groups Projects
Commit 499b75a8 authored by Daniele Nicolodi's avatar Daniele Nicolodi
Browse files

Refactor RPC implementation

This makes the RPC code general and reusable.
parent 76f088d6
No related branches found
No related tags found
No related merge requests found
import zmq
import click
import pickle
import cryod
import rpc
@click.group()
@click.pass_context
def main(ctx):
context = zmq.Context()
socket = context.socket(zmq.REQ)
socket.connect(cryod.ADDRESS)
ctx.obj = socket
ctx.obj = rpc.Client(cryod.ADDRESS, cryod.TIMEOUT)
@main.command()
@click.pass_obj
@click.argument('state', type=click.Choice(('on', 'off')))
def power(socket, state):
socket.send_multipart((b'POWER', pickle.dumps(int(state == 'on'))))
if not socket.poll(cryod.TIMEOUT) & zmq.POLLIN:
raise TimeoutError
reply, args = socket.recv_multipart()
if reply != b'ACK' or pickle.loads(args) is not None:
def power(compressor, state):
r = compressor.power(int(state == 'on'))
if r is not None:
print(reply)
......
import asyncio
import click
import contextlib
import pickle
import zmq.asyncio
import eventloop
import cpa28xx
import rpc
ADDRESS = 'tcp://127.0.0.1:65534'
TIMEOUT = 3000 # milliseconds
class Server:
def __init__(self, address, compressor):
ctx = zmq.asyncio.Context()
self.socket = ctx.socket(zmq.REP)
self.socket.bind(address)
self.compressor = compressor
async def __call__(self):
while True:
request, args = await self.socket.recv_multipart()
func = getattr(self, request.decode('utf8').lower())
reply = func(pickle.loads(args))
await self.socket.send_multipart((b'ACK', pickle.dumps(reply)))
def power(self, state):
return self.compressor.power(state)
TIMEOUT = 3.0 # seconds
async def status(compressor, filename):
......@@ -45,7 +25,7 @@ async def status(compressor, filename):
@click.option('--device', '-d', metavar='DEV', help='Communication port.')
def main(device, filename):
compressor = cpa28xx.CPA28xx(device)
server = Server(ADDRESS, compressor)
server = rpc.Server(ADDRESS, compressor)
async def main():
await asyncio.gather(status(compressor, filename), server())
......
rpc.py 0 → 100644
import functools
import pickle
import zmq
import zmq.asyncio
class RPCError(RuntimeError):
pass
class RPCTimeoutError(TimeoutError):
pass
class Server:
"""Simple asynchronous RPC server."""
def __init__(self, address, obj):
ctx = zmq.asyncio.Context()
self.socket = ctx.socket(zmq.REP)
self.socket.bind(address)
self.obj = obj
async def __call__(self):
while True:
request, args = await self.socket.recv_multipart()
func = getattr(self.obj, request.decode('utf8'))
reply = func(*pickle.loads(args))
await self.socket.send_multipart((b'ACK', pickle.dumps(reply)))
class Client:
"""Simple synchronous RPC client."""
def __init__(self, address, timeout=1.0):
self.ctx = zmq.Context()
self.socket = self.ctx.socket(zmq.REQ)
self.socket.connect(address)
self.timeout = int(timeout * 1000)
def call(self, method, *args):
self.socket.send_multipart((method.encode('utf8'), pickle.dumps(args)))
if not self.socket.poll(self.timeout) & zmq.POLLIN:
raise RPCTimeoutError
reply, s = self.socket.recv_multipart()
values = pickle.loads(s)
if reply != b'ACK':
raise RPCError(s)
return values
def __getattr__(self, name):
return functools.partial(self.call, name)
import rpc
def main():
client = rpc.Client('tcp://127.0.0.1:1234')
n = 0
for _ in range(10):
r = client.call('inc', n)
assert r == n + 1
n = r
print(n)
for _ in range(10):
r = client.inc(n)
assert r == n + 1
n = r
print(n)
if __name__ == '__main__':
main()
import asyncio
import eventloop
import rpc
class Test:
def inc(self, n):
return n + 1
def main():
server = rpc.Server('tcp://127.0.0.1:1234', Test())
asyncio.set_event_loop_policy(eventloop.EventLoopPolicy())
asyncio.run(server())
if __name__ == '__main__':
main()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment