Skip to content
Snippets Groups Projects
Commit db5c3558 authored by KERDREUX Jerome's avatar KERDREUX Jerome
Browse files

Added some extra type check

Body should be a dict or none .. no more sloopy stuff
parent 890d18a0
No related branches found
No related tags found
1 merge request!1First try of type hints
...@@ -59,10 +59,10 @@ class TestEngine(unittest.TestCase): ...@@ -59,10 +59,10 @@ class TestEngine(unittest.TestCase):
target = Device("test.basic", tools.get_random_uuid()) target = Device("test.basic", tools.get_random_uuid())
def action_1(): def action_1():
return "action_1" return {"value":"action_1"}
def action_2(_value=None): def action_2(_value=None):
return "action_%s" % _value return {"value":"action_%s" % _value}
def action_3(): def action_3():
raise Exception raise Exception
...@@ -77,12 +77,12 @@ class TestEngine(unittest.TestCase): ...@@ -77,12 +77,12 @@ class TestEngine(unittest.TestCase):
# simple test method # simple test method
msg.action = "action_1" msg.action = "action_1"
result = engine.run_action(msg, target) result = engine.run_action(msg, target)
self.assertEqual(result, "action_1") self.assertEqual(result, {"value":"action_1"})
# test with value # test with value
msg.action = "action_2" msg.action = "action_2"
msg.body = {"value": "2"} msg.body = {"value": "2"}
result = engine.run_action(msg, target) result = engine.run_action(msg, target)
self.assertEqual(result, "action_2") self.assertEqual(result, {"value":"action_2"})
# Exception in method # Exception in method
msg.action = "action_3" msg.action = "action_3"
with self.assertRaises(engine.XAALError): with self.assertRaises(engine.XAALError):
......
...@@ -259,7 +259,7 @@ class AsyncEngine(core.EngineMixin): ...@@ -259,7 +259,7 @@ class AsyncEngine(core.EngineMixin):
# process alives every 10 seconds # process alives every 10 seconds
self.add_timer(self.process_alives, 10) self.add_timer(self.process_alives, 10)
async def stop(self): async def stop(self): # pyright: ignore
logger.info("Stopping engine") logger.info("Stopping engine")
await self.run_hooks(HookType.stop) await self.run_hooks(HookType.stop)
self.running_event.clear() self.running_event.clear()
...@@ -391,8 +391,7 @@ async def console(locals=locals(), port: Optional[int] = None): ...@@ -391,8 +391,7 @@ async def console(locals=locals(), port: Optional[int] = None):
# let's find a free port if not specified # let's find a free port if not specified
def find_free_port(): def find_free_port():
import socketserver import socketserver
with socketserver.TCPServer(('localhost', 0), None) as s: # pyright: ignore pyright reject the None here
with socketserver.TCPServer(('localhost', 0), None) as s:
return s.server_address[1] return s.server_address[1]
port = find_free_port() port = find_free_port()
...@@ -408,6 +407,6 @@ async def console(locals=locals(), port: Optional[int] = None): ...@@ -408,6 +407,6 @@ async def console(locals=locals(), port: Optional[int] = None):
# start the console # start the console
try: try:
# debian with ipv6 disabled still state that localhost is ::1, which broke aioconsole # debian with ipv6 disabled still state that localhost is ::1, which broke aioconsole
await aioconsole.start_interactive_server(host="127.0.0.1", port=port, factory=factory, banner=banner) await aioconsole.start_interactive_server(host="127.0.0.1", port=port, factory=factory, banner=banner) # pyright: ignore
except OSError: except OSError:
logger.warning("Unable to run console") logger.warning("Unable to run console")
...@@ -29,7 +29,7 @@ class AsyncNetworkConnector(object): ...@@ -29,7 +29,7 @@ class AsyncNetworkConnector(object):
try: try:
# Linux + MacOS + BSD # Linux + MacOS + BSD
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1) sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
except: except Exception:
# Windows # Windows
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
sock.bind((self.bind_addr, self.port)) sock.bind((self.bind_addr, self.port))
......
...@@ -25,6 +25,8 @@ import typing ...@@ -25,6 +25,8 @@ import typing
from enum import Enum from enum import Enum
from typing import Optional from typing import Optional
from .config import config from .config import config
from . import core from . import core
from .exceptions import CallbackError, MessageParserError, XAALError from .exceptions import CallbackError, MessageParserError, XAALError
...@@ -115,8 +117,8 @@ class Engine(core.EngineMixin): ...@@ -115,8 +117,8 @@ class Engine(core.EngineMixin):
request for each targets identied in the engine request for each targets identied in the engine
""" """
if not msg.is_request(): if not msg.is_request():
return return
targets = core.filter_msg_for_devices(msg, self.devices) targets = core.filter_msg_for_devices(msg, self.devices)
for target in targets: for target in targets:
if msg.is_request_isalive(): if msg.is_request_isalive():
...@@ -240,4 +242,8 @@ def run_action(msg: 'Message', device: 'Device') -> Optional[dict]: ...@@ -240,4 +242,8 @@ def run_action(msg: 'Message', device: 'Device') -> Optional[dict]:
except Exception as e: except Exception as e:
logger.error(e) logger.error(e)
raise XAALError("Error in method:%s params:%s" % (msg.action, params)) raise XAALError("Error in method:%s params:%s" % (msg.action, params))
# Here result should be None or a dict, and we need to enforce that. This will cause issue
# in send_reply otherwise.
if result is not None and not isinstance(result, dict):
raise XAALError("Method %s should return a dict or None" % msg.action)
return result return result
...@@ -22,7 +22,6 @@ import datetime ...@@ -22,7 +22,6 @@ import datetime
import logging import logging
import pprint import pprint
import struct import struct
import sys
import typing import typing
from enum import Enum from enum import Enum
from typing import Any, Optional from typing import Any, Optional
...@@ -31,8 +30,8 @@ import pysodium ...@@ -31,8 +30,8 @@ import pysodium
from tabulate import tabulate from tabulate import tabulate
from . import cbor, tools from . import cbor, tools
from .config import config
from .bindings import UUID from .bindings import UUID
from .config import config
from .exceptions import MessageError, MessageParserError from .exceptions import MessageError, MessageParserError
if typing.TYPE_CHECKING: if typing.TYPE_CHECKING:
...@@ -213,7 +212,7 @@ class MessageFactory(object): ...@@ -213,7 +212,7 @@ class MessageFactory(object):
# Decode cbor incoming data # Decode cbor incoming data
try: try:
data_rx = cbor.loads(data) data_rx = cbor.loads(data)
except: except Exception:
raise MessageParserError("Unable to parse CBOR data") raise MessageParserError("Unable to parse CBOR data")
# Instanciate Message, parse the security layer # Instanciate Message, parse the security layer
...@@ -251,13 +250,13 @@ class MessageFactory(object): ...@@ -251,13 +250,13 @@ class MessageFactory(object):
nonce = build_nonce(msg.timestamp) nonce = build_nonce(msg.timestamp)
try: try:
clear = pysodium.crypto_aead_chacha20poly1305_ietf_decrypt(ciph, ad, nonce, self.cipher_key) clear = pysodium.crypto_aead_chacha20poly1305_ietf_decrypt(ciph, ad, nonce, self.cipher_key)
except: except Exception:
raise MessageParserError("Unable to decrypt msg") raise MessageParserError("Unable to decrypt msg")
# Decode application layer (payload) # Decode application layer (payload)
try: try:
payload = cbor.loads(clear) payload = cbor.loads(clear)
except: except Exception:
raise MessageParserError("Unable to parse CBOR data in payload after decrypt") raise MessageParserError("Unable to parse CBOR data in payload after decrypt")
try: try:
msg.source = UUID(bytes=payload[0]) msg.source = UUID(bytes=payload[0])
...@@ -345,12 +344,15 @@ def build_timestamp() -> tuple: ...@@ -345,12 +344,15 @@ def build_timestamp() -> tuple:
"""Return array [seconds since epoch, microseconds since last seconds] Time = UTC+0000""" """Return array [seconds since epoch, microseconds since last seconds] Time = UTC+0000"""
epoch = datetime.datetime.fromtimestamp(0, datetime.UTC) epoch = datetime.datetime.fromtimestamp(0, datetime.UTC)
timestamp = datetime.datetime.now(datetime.UTC) - epoch timestamp = datetime.datetime.now(datetime.UTC) - epoch
return _packtimestamp(timestamp.total_seconds(), timestamp.microseconds) return (int(timestamp.total_seconds()), int(timestamp.microseconds))
## This stuff below is for Py2/Py3 compatibility. In the current state of xAAL, we only use
# Py3. This code is here for archive purpose and could be removed in the future.
# for better performance, I choose to use this trick to fix the change in size for Py3. # for better performance, I choose to use this trick to fix the change in size for Py3.
# only test once. # only test once.
if sys.version_info.major == 2: # if sys.version_info.major == 2:
_packtimestamp = lambda t1, t2: (long(t1), int(t2)) # _packtimestamp = lambda t1, t2: (long(t1), int(t2)) # pyright: ignore
else: # else:
_packtimestamp = lambda t1, t2: (int(t1), int(t2)) # _packtimestamp = lambda t1, t2: (int(t1), int(t2))
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment