Skip to content
Snippets Groups Projects
Commit ae8e6427 authored by KERDREUX Jerome's avatar KERDREUX Jerome
Browse files

Massive type hinting / Format / import ordering

parent bcda1261
Branches
No related tags found
1 merge request!1First try of type hints
import asyncio
from . import core
from . import config
from . import tools
from .messages import MessageParserError
from .aionetwork import AsyncNetworkConnector
from .exceptions import XAALError, CallbackError
import logging
import signal
import sys
import time
from enum import Enum
from pprint import pprint
import aioconsole
import signal
import sys
from tabulate import tabulate
from pprint import pprint
from . import config, core, tools
from .aionetwork import AsyncNetworkConnector
from .exceptions import CallbackError, XAALError
from .messages import MessageParserError
import logging
logger = logging.getLogger(__name__)
class AsyncEngine(core.EngineMixin):
......@@ -132,7 +129,7 @@ class AsyncEngine(core.EngineMixin):
async def handle_action_request(self, msg, target):
try:
result = await run_action(msg, target)
if result is None:
if result is not None:
self.send_reply(dev=target,targets=[msg.source],action=msg.action,body=result)
except CallbackError as e:
self.send_error(target, e.code, e.description)
......@@ -143,7 +140,7 @@ class AsyncEngine(core.EngineMixin):
# Asyncio loop & Tasks
#####################################################
def get_loop(self):
if self._loop == None:
if self._loop is None:
logger.debug('New event loop')
self._loop = asyncio.get_event_loop()
return self._loop
......
import asyncio
import struct
import socket
import logging
import socket
import struct
logger = logging.getLogger(__name__)
class AsyncNetworkConnector(object):
def __init__(self, addr, port, hops, bind_addr="0.0.0.0"):
def __init__(self, addr: str, port: int, hops: int, bind_addr="0.0.0.0"):
self.addr = addr
self.port = port
self.hops = hops
......@@ -34,21 +33,19 @@ class AsyncNetworkConnector(object):
# Windows
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
sock.bind((self.bind_addr, self.port))
mreq = struct.pack(
"=4s4s", socket.inet_aton(self.addr), socket.inet_aton(self.bind_addr)
)
mreq = struct.pack("=4s4s", socket.inet_aton(self.addr), socket.inet_aton(self.bind_addr))
sock.setsockopt(socket.IPPROTO_IP, socket.IP_ADD_MEMBERSHIP, mreq)
sock.setsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_TTL, 10)
sock.setblocking(False)
return sock
def send(self, data):
def send(self, data: bytes):
self.protocol.datagram_send(data, self.addr, self.port)
def receive(self, data):
def receive(self, data: bytes):
self._rx_queue.put_nowait(data)
async def get_data(self):
async def get_data(self) -> bytes:
return await self._rx_queue.get()
......
......@@ -18,20 +18,38 @@
# along with xAAL. If not, see <http://www.gnu.org/licenses/>.
#
from .messages import MessageType, MessageAction, MessageFactory, ALIVE_ADDR
import inspect
import logging
import time
import typing
from typing import Any, Optional, List
from .exceptions import EngineError, XAALError
from .messages import ALIVE_ADDR, MessageAction, MessageFactory, MessageType
if typing.TYPE_CHECKING:
from .devices import Device, Attribute
from .messages import Message
import time
import inspect
import logging
logger = logging.getLogger(__name__)
class EngineMixin(object):
__slots__ = ['devices','timers','subscribers','msg_filter','_attributesChange','network','msg_factory']
#####################################################
# Timer class
#####################################################
class Timer(object):
def __init__(self, func, period, counter):
self.func = func
self.period = period
self.counter = counter
self.deadline = time.time() + period
def __init__(self,address,port,hops,key):
class EngineMixin(object):
__slots__ = ["devices", "timers", "subscribers", "msg_filter", "_attributesChange", "network", "msg_factory"]
def __init__(self, address: int, port: int, hops: int, key: str):
self.devices = [] # list of devices / use (un)register_devices()
self.timers = [] # functions to call periodic
self.subscribers = [] # message receive workflow
......@@ -49,7 +67,7 @@ class EngineMixin(object):
#####################################################
# Devices management
#####################################################
def add_device(self, dev):
def add_device(self, dev: "Device"):
"""register a new device"""
if dev not in self.devices:
self.devices.append(dev)
......@@ -57,12 +75,12 @@ class EngineMixin(object):
if self.is_running():
self.send_alive(dev)
def add_devices(self, devs):
def add_devices(self, devs: List["Device"]):
"""register new devices"""
for dev in devs:
self.add_device(dev)
def remove_device(self, dev):
def remove_device(self, dev: "Device"):
"""unregister a device"""
dev.engine = None
# Remove dev from devices list
......@@ -72,54 +90,62 @@ class EngineMixin(object):
# xAAL messages Tx handling
#####################################################
# Fifo for msg to send
def queue_msg(self, msg):
def queue_msg(self, msg: bytes):
logger.critical("To be implemented queue_msg: %s", msg)
def send_request(self, dev, targets, action, body=None):
def send_request(self, dev: "Device", targets: list, action: str, body: Optional[dict] = None):
"""queue a new request"""
msg = self.msg_factory.build_msg(dev, targets, MessageType.REQUEST, action, body)
self.queue_msg(msg)
def send_reply(self, dev, targets, action, body=None):
def send_reply(self, dev: "Device", targets: list, action: str, body: Optional[dict] = None):
"""queue a new reply"""
msg = self.msg_factory.build_msg(dev, targets, MessageType.REPLY, action, body)
self.queue_msg(msg)
def send_error(self, dev, errcode, description=None):
def send_error(self, dev: "Device", errcode: int, description: Optional[str] = None):
"""queue a error message"""
msg = self.msg_factory.build_error_msg(dev, errcode, description)
self.queue_msg(msg)
def send_get_description(self, dev, targets):
def send_get_description(self, dev: "Device", targets: list):
"""queue a get_description request"""
self.send_request(dev, targets, MessageAction.GET_DESCRIPTION.value)
def send_get_attributes(self, dev, targets):
def send_get_attributes(self, dev: "Device", targets: list):
"""queue a get_attributes request"""
self.send_request(dev, targets, MessageAction.GET_ATTRIBUTES.value)
def send_notification(self, dev, action, body=None):
def send_notification(self, dev: "Device", action: str, body: Optional[dict] = None):
"""queue a notificaton"""
msg = self.msg_factory.build_msg(dev, [], MessageType.NOTIFY, action, body)
self.queue_msg(msg)
def send_alive(self, dev):
def send_alive(self, dev: "Device"):
"""Send a Alive message for a given device"""
timeout = dev.get_timeout()
msg = self.msg_factory.build_alive_for(dev, timeout)
self.queue_msg(msg)
dev.update_alive()
def send_is_alive(self, dev, targets=[ALIVE_ADDR,], dev_types=["any.any",]):
def send_is_alive(
self,
dev: "Device",
targets: list = [
ALIVE_ADDR,
],
dev_types: list = [
"any.any",
],
):
"""Send a is_alive message, w/ dev_types filtering"""
body = {'dev_types': dev_types}
body = {"dev_types": dev_types}
self.send_request(dev, targets, MessageAction.IS_ALIVE.value, body)
#####################################################
# Messages filtering
#####################################################
def enable_msg_filter(self, func=None):
def enable_msg_filter(self, func: Any = None):
"""enable message filter"""
self.msg_filter = func or self.default_msg_filter
......@@ -127,7 +153,7 @@ class EngineMixin(object):
"""disable message filter"""
self.msg_filter = None
def default_msg_filter(self, msg):
def default_msg_filter(self, msg: "Message"):
"""
Filter messages:
- check if message has alive request address
......@@ -156,11 +182,11 @@ class EngineMixin(object):
#####################################################
# xAAL attributes changes
#####################################################
def add_attributes_change(self, attr):
def add_attributes_change(self, attr: "Attribute"):
"""add a new attribute change to the list"""
self._attributesChange.append(attr)
def get_attributes_change(self):
def get_attributes_change(self) -> List["Attribute"]:
"""return the pending attributes changes list"""
return self._attributesChange
......@@ -180,16 +206,16 @@ class EngineMixin(object):
#####################################################
# xAAL messages subscribers
#####################################################
def subscribe(self,func):
def subscribe(self, func: Any):
self.subscribers.append(func)
def unsubscribe(self,func):
def unsubscribe(self, func: Any):
self.subscribers.remove(func)
#####################################################
# timers
#####################################################
def add_timer(self, func, period,counter=-1):
def add_timer(self, func: Any, period: int, counter: int = -1):
"""
func: function to call
period: period in second
......@@ -201,7 +227,7 @@ class EngineMixin(object):
self.timers.append(t)
return t
def remove_timer(self, timer):
def remove_timer(self, timer: Timer):
"""remove a given timer from the list"""
self.timers.remove(timer)
......@@ -224,20 +250,11 @@ class EngineMixin(object):
logger.critical("To be implemented is_running")
return False
#####################################################
# Timer class
#####################################################
class Timer(object):
def __init__(self, func, period, counter):
self.func = func
self.period = period
self.counter = counter
self.deadline = time.time() + period
#####################################################
# Usefull functions to Engine developpers
#####################################################
def filter_msg_for_devices(msg, devices):
def filter_msg_for_devices(msg: "Message", devices: List["Device"]):
"""
loop throught the devices, to find which are expected w/ the msg
- Filter on dev_types for is_alive broadcast request.
......@@ -247,13 +264,13 @@ def filter_msg_for_devices(msg, devices):
if msg.is_request_isalive() and (ALIVE_ADDR in msg.targets):
# if we receive a broadcast is_alive request, we reply
# with filtering on dev_tyes.
if 'dev_types' in msg.body.keys():
dev_types = msg.body['dev_types']
if 'any.any' in dev_types:
if "dev_types" in msg.body.keys():
dev_types = msg.body["dev_types"]
if "any.any" in dev_types:
results = devices
else:
for dev in devices:
any_subtype = dev.dev_type.split('.')[0] + '.any'
any_subtype = dev.dev_type.split(".")[0] + ".any"
if dev.dev_type in dev_types:
results.append(dev)
elif any_subtype in dev_types:
......@@ -267,7 +284,8 @@ def filter_msg_for_devices(msg, devices):
results.append(dev)
return results
def search_action(msg, device):
def search_action(msg: "Message", device: "Device"):
"""
Extract an action (match with methods) from a msg on the device.
Return:
......@@ -287,7 +305,7 @@ def search_action(msg, device):
body_params = msg.body
for k in body_params:
temp = '_%s' %k
temp = "_%s" % k
if temp in method_params:
params.update({temp: body_params[k]})
else:
......@@ -297,13 +315,12 @@ def search_action(msg, device):
raise XAALError("Method %s not found on device %s" % (msg.action, device))
return result
def get_args_method(method):
def get_args_method(method: Any) -> List[str]:
"""return the list on arguments for a given python method"""
spec = inspect.getfullargspec(method)
try:
spec.args.remove('self')
spec.args.remove("self")
except Exception:
pass
return spec.args
......@@ -7,18 +7,22 @@ import logging
import logging.handlers
import os
import time
from typing import Optional, Any
from typing import Any, Optional
import coloredlogs
from decorator import decorator
from . import config
def singleton(class_):
instances = {}
def getinstance(*args, **kwargs):
if class_ not in instances:
instances[class_] = class_(*args, **kwargs)
return instances[class_]
return getinstance
......@@ -28,28 +32,32 @@ def timeit(method, *args, **kwargs):
ts = time.time()
result = method(*args, **kwargs)
te = time.time()
logger.debug('%r (%r, %r) %2.6f sec' % (method.__name__, args, kwargs, te-ts))
logger.debug("%r (%r, %r) %2.6f sec" % (method.__name__, args, kwargs, te - ts))
return result
def set_console_title(value: str):
# set xterm title
print("\x1B]0;xAAL => %s\x07" % value, end='\r')
print("\x1b]0;xAAL => %s\x07" % value, end="\r")
def setup_console_logger(level: str = config.log_level):
fmt = '%(asctime)s %(name)-25s %(funcName)-18s %(levelname)-8s %(message)s'
fmt = "%(asctime)s %(name)-25s %(funcName)-18s %(levelname)-8s %(message)s"
# fmt = '[%(name)s] %(funcName)s %(levelname)s: %(message)s'
coloredlogs.install(level=level, fmt=fmt)
def setup_file_logger(name: str, level: str = config.log_level, filename: Optional[str] = None):
filename = filename or os.path.join(config.log_path,'%s.log' % name)
formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(name)s - %(message)s')
handler = logging.handlers.RotatingFileHandler(filename, 'a', 10000, 1, 'utf8')
filename = filename or os.path.join(config.log_path, "%s.log" % name)
formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(name)s - %(message)s")
handler = logging.handlers.RotatingFileHandler(filename, "a", 10000, 1, "utf8")
handler.setLevel(level)
handler.setFormatter(formatter)
# register the new handler
logger = logging.getLogger(name)
logger.root.addHandler(handler)
logger.root.setLevel('DEBUG')
logger.root.setLevel("DEBUG")
# ---------------------------------------------------------------------------
# TBD: We should merge this stuffs, and add support for default config file
......@@ -64,9 +72,10 @@ def run_package(pkg_name: str, pkg_setup: Any, console_log: bool = True, file_lo
if file_log:
setup_file_logger(pkg_name)
logger = logging.getLogger(pkg_name)
logger.info('starting xaal package: %s'% pkg_name )
logger.info("starting xaal package: %s" % pkg_name)
from .engine import Engine
eng = Engine()
result = pkg_setup(eng)
......@@ -78,4 +87,5 @@ def run_package(pkg_name: str, pkg_setup: Any, console_log: bool = True, file_lo
eng.shutdown()
logger.info("exit")
__all__ = ['singleton','timeit','set_console_title','setup_console_logger','setup_file_logger','run_package']
__all__ = ["singleton", "timeit", "set_console_title", "setup_console_logger", "setup_file_logger", "run_package"]
......@@ -19,11 +19,12 @@
#
import datetime
import logging
import pprint
import struct
import sys
from enum import Enum
import typing
from enum import Enum
from typing import Any, Optional
import pysodium
......@@ -37,8 +38,6 @@ if typing.TYPE_CHECKING:
from .devices import Device
import logging
logger = logging.getLogger(__name__)
ALIVE_ADDR = UUID("00000000-0000-0000-0000-000000000000")
......@@ -277,7 +276,7 @@ class MessageFactory(object):
#####################################################
def build_msg(
self,
dev: Optional['Device'] = None,
dev: Optional["Device"] = None,
targets: list = [],
msg_type: Optional[MessageType] = None,
action: Optional[str] = None,
......@@ -309,30 +308,31 @@ class MessageFactory(object):
data = self.encode_msg(message)
return data
def build_alive_for(self, dev: 'Device', timeout: int =0) -> bytes:
def build_alive_for(self, dev: "Device", timeout: int = 0) -> bytes:
"""Build Alive message for a given device
timeout = 0 is the minimum value
"""
body = {}
body['timeout'] = timeout
message = self.build_msg(dev=dev, targets=[], msg_type=MessageType.NOTIFY, action=MessageAction.ALIVE.value, body=body)
body["timeout"] = timeout
message = self.build_msg(
dev=dev, targets=[], msg_type=MessageType.NOTIFY, action=MessageAction.ALIVE.value, body=body
)
return message
def build_error_msg(self, dev: 'Device', errcode:int , description: Optional[str] =None):
def build_error_msg(self, dev: "Device", errcode: int, description: Optional[str] = None):
"""Build a Error message"""
message = Message()
body = {}
body['code'] = errcode
body["code"] = errcode
if description:
body['description'] = description
body["description"] = description
message = self.build_msg(dev, [], MessageType.NOTIFY, "error", body)
return message
def build_nonce(data: tuple) -> bytes:
"""Big-Endian, time in seconds and time in microseconds"""
nonce = struct.pack('>QL', data[0], data[1])
nonce = struct.pack(">QL", data[0], data[1])
return nonce
......
......@@ -18,12 +18,13 @@
# along with xAAL. If not, see <http://www.gnu.org/licenses/>.
#
import logging
import select
import socket
import struct
import select
import logging
import time
from enum import Enum
from typing import Optional
logger = logging.getLogger(__name__)
......@@ -36,7 +37,7 @@ class NetworkState(Enum):
class NetworkConnector(object):
UDP_MAX_SIZE = 65507
def __init__(self, addr, port, hops,bind_addr='0.0.0.0'):
def __init__(self, addr: str, port: int, hops: int, bind_addr="0.0.0.0"):
self.addr = addr
self.port = port
self.hops = hops
......@@ -70,17 +71,17 @@ class NetworkConnector(object):
def is_connected(self):
return self.state == NetworkState.connected
def receive(self):
def receive(self) -> bytes:
packt = self.__sock.recv(self.UDP_MAX_SIZE)
return packt
def __get_data(self):
def __get_data(self) -> Optional[bytes]:
r = select.select([self.__sock,], [], [], 0.02)
if r[0]:
return self.receive()
return None
def get_data(self):
def get_data(self) -> Optional[bytes]:
if not self.is_connected():
self.connect()
try:
......@@ -88,7 +89,7 @@ class NetworkConnector(object):
except Exception as e:
self.network_error(e)
def send(self,data):
def send(self, data: bytes):
if not self.is_connected():
self.connect()
try:
......@@ -96,7 +97,7 @@ class NetworkConnector(object):
except Exception as e:
self.network_error(e)
def network_error(self, msg):
def network_error(self, ex: Exception):
self.disconnect()
logger.info("Network error, reconnect..%s" % msg)
logger.info("Network error, reconnect..%s" % ex.__str__())
time.sleep(5)
......@@ -18,39 +18,42 @@
# along with xAAL. If not, see <http://www.gnu.org/licenses/>.
#
import functools
import os
import re
import sys
from typing import Optional, Union
import pysodium
import sys
import functools
from configobj import ConfigObj
from . import config
from .bindings import UUID
XAAL_DEVTYPE_PATTERN = '^[a-zA-Z][a-zA-Z0-9_-]*\\.[a-zA-Z][a-zA-Z0-9_-]*$'
XAAL_DEVTYPE_PATTERN = "^[a-zA-Z][a-zA-Z0-9_-]*\\.[a-zA-Z][a-zA-Z0-9_-]*$"
def get_cfg_filename(name: str, cfg_dir: str = config.conf_dir) -> str:
if name.startswith('xaal.'):
if name.startswith("xaal."):
name = name[5:]
filename = '%s.ini' % name
filename = "%s.ini" % name
if not os.path.isdir(cfg_dir):
print("Your configuration directory doesn't exist: [%s]" % cfg_dir)
return os.path.join(cfg_dir, filename)
def load_cfg_file(filename: str) -> Optional[ConfigObj]:
"""load .ini file and return it as dict"""
if os.path.isfile(filename):
return ConfigObj(filename,indent_type=' ',encoding="utf8")
return ConfigObj(filename, indent_type=" ", encoding="utf8")
return None
def load_cfg(app_name: str) -> Optional[ConfigObj]:
filename = get_cfg_filename(app_name)
return load_cfg_file(filename)
def load_cfg_or_die(app_name: str) -> ConfigObj:
cfg = load_cfg(app_name)
if not cfg:
......@@ -58,19 +61,23 @@ def load_cfg_or_die(app_name: str) -> ConfigObj:
sys.exit(-1)
return cfg
def new_cfg(app_name: str) -> ConfigObj:
filename = get_cfg_filename(app_name)
cfg = ConfigObj(filename,indent_type=' ')
cfg['config'] = {}
cfg['config']['addr']=get_random_uuid().str
cfg = ConfigObj(filename, indent_type=" ")
cfg["config"] = {}
cfg["config"]["addr"] = get_random_uuid().str
return cfg
def get_random_uuid() -> UUID:
return UUID.random()
def get_random_base_uuid(digit=2) -> UUID:
return UUID.random_base(digit)
def get_uuid(val: Union[UUID, str, None]) -> Optional[UUID]:
if isinstance(val, UUID):
return val
......@@ -78,6 +85,7 @@ def get_uuid(val: Union[UUID, str, None]) -> Optional[UUID]:
return str_to_uuid(val)
return None
def str_to_uuid(val: str) -> Optional[UUID]:
"""return an xAAL address for a given string"""
try:
......@@ -85,18 +93,22 @@ def str_to_uuid(val: str) -> Optional[UUID]:
except ValueError:
return None
def bytes_to_uuid(val: bytes) -> Optional[UUID]:
try:
return UUID(bytes=val)
except ValueError:
return None
def is_valid_uuid(val: Union[UUID, str]) -> bool:
return isinstance(val, UUID)
def is_valid_address(val: Union[UUID, str]) -> bool:
return is_valid_uuid(val)
@functools.lru_cache(maxsize=128)
def is_valid_dev_type(val: str) -> bool:
if not isinstance(val, str):
......@@ -105,6 +117,7 @@ def is_valid_dev_type(val: str) -> bool:
return True
return False
def pass2key(passphrase: str) -> bytes:
"""Generate key from passphrase using libsodium
crypto_pwhash_scryptsalsa208sha256 func
......@@ -112,19 +125,20 @@ def pass2key(passphrase: str) -> bytes:
opslimit: crypto_pwhash_scryptsalsa208sha256_OPSLIMIT_INTERACTIVE
memlimit: crypto_pwhash_scryptsalsa208sha256_MEMLIMIT_INTERACTIVE
"""
buf = passphrase.encode('utf-8')
buf = passphrase.encode("utf-8")
KEY_BYTES = pysodium.crypto_pwhash_scryptsalsa208sha256_SALTBYTES # 32
# this should be:
# salt = bytes(KEY_BYTES)
# but due to bytes() stupid stuff in py2 we need this awfull stuff
salt = ('\00' * KEY_BYTES).encode('utf-8')
salt = ("\00" * KEY_BYTES).encode("utf-8")
opslimit = pysodium.crypto_pwhash_scryptsalsa208sha256_OPSLIMIT_INTERACTIVE
memlimit = pysodium.crypto_pwhash_scryptsalsa208sha256_MEMLIMIT_INTERACTIVE
key = pysodium.crypto_pwhash_scryptsalsa208sha256(KEY_BYTES, buf, salt, opslimit, memlimit)
return key
@functools.lru_cache(maxsize=128)
def reduce_addr(addr: UUID) -> str:
"""return a string based addred without all digits"""
tmp = addr.str
return tmp[:5] + '..' + tmp[-5:]
return tmp[:5] + ".." + tmp[-5:]
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment