detect-gui/core/edge_context.py

473 lines
21 KiB
Python
Raw Normal View History

2024-11-21 11:39:52 +08:00
import asyncio
import importlib
import inspect
import json
import os
import pkgutil
import re
from typing import Type, Dict, Optional, Any, List
from fastapi import FastAPI, Request
from fastapi.staticfiles import StaticFiles
from httpx import delete
from starlette.responses import JSONResponse
from uvicorn import Server, Config
from core.edge_component import EdgeComponent, service
from core.config import settings
from core.edge_internal import Database, Scheduler
from core.edge_logger import LoggingMixin
import paho.mqtt.client as mqtt
from core.edge_routes import setup_routes
from core.edge_task import EdgeTask, EdgeTaskStep
from core.edge_util import camel_to_snake
# MQTT topic layout used by EdgeContext. The "v1/devices/me/..." names look
# like the ThingsBoard device MQTT API (NOTE(review): confirm the target broker).
_MQTT_DEVICE_ROOT = "v1/devices/me"
MQTT_TELEMETRY_TOPIC = _MQTT_DEVICE_ROOT + "/telemetry"
# RPC requests arrive on "<prefix>/<rpc id>"; replies go to "<resp prefix>/<rpc id>".
MQTT_RPC_TOPIC_PREFIX = _MQTT_DEVICE_ROOT + "/rpc/request"
MQTT_RPC_TOPIC_RESP_PREFIX = _MQTT_DEVICE_ROOT + "/rpc/response"
# Subscription filter matching any single-level RPC id ("+" wildcard).
MQTT_RPC_TOPIC = MQTT_RPC_TOPIC_PREFIX + "/+"
class EdgeContext(LoggingMixin):
    """Runtime container of the edge node.

    On construction it instantiates every concrete ``EdgeComponent`` found in
    the ``components`` package plus the internal ``Database`` and ``Scheduler``,
    configures them with ``settings`` and starts those that opt in via
    ``component_auto_start``. Their ``@service`` methods are reachable through:

    * MQTT RPC: requests on ``MQTT_RPC_TOPIC``, replies on
      ``MQTT_RPC_TOPIC_RESP_PREFIX/<rpc id>``;
    * an optional raw-TCP JSON server (``tcp.enable``);
    * a FastAPI app served by uvicorn (routes attached by ``setup_routes``).

    MQTT and TCP share one JSON request envelope, answered by
    ``_dispatch_request``::

        {"requestId": ..., "type": "service" | "task",
         "component": <name, or ""/absent for this context>,
         "method": <camelCase or snake_case name>, "params": {...}}

    NOTE(review): the MQTT/TCP request paths in this class perform no
    authentication, and both servers bind to all interfaces by default.
    """

    def __init__(self):
        super().__init__()
        # Plain runtime state is set first so these attributes already exist
        # if a component touches them while being constructed or started.
        self.loop: Optional[asyncio.AbstractEventLoop] = None
        self.tcp_server: Optional[asyncio.AbstractServer] = None
        self.tcp_clients = set()  # StreamWriter of every live TCP connection
        self._server_tasks = []  # strong refs to the server tasks made in start()
        # MQTT client; RPC requests are handled by the callbacks below.
        self.client = mqtt.Client()
        self.client.on_connect = self.on_connect
        self.client.on_disconnect = self.on_disconnect
        self.client.on_message = self.on_message
        # Discover, configure and (optionally) start the edge components.
        self.components: Dict[str, EdgeComponent] = self._load_components()
        self._init_internal_components()
        self._init_components()
        self.logger.info(f"当前边缘组件:{','.join(self.components.keys())}")
        # REST: FastAPI app served by uvicorn.
        rest_host = settings.get("rest.host", "0.0.0.0")
        rest_port = settings.get("rest.port", 8000)
        self.app = FastAPI()
        # Expose the configured static directory under /local.
        local_dir = settings.get('rest.static', 'local')
        os.makedirs(local_dir, exist_ok=True)
        self.app.mount("/local", StaticFiles(directory=local_dir), name="local")
        self.config = Config(app=self.app, host=rest_host, port=rest_port)
        self.server = Server(self.config)
        setup_routes(self, self.app)

    def _load_components(self) -> Dict[str, EdgeComponent]:
        """Instantiate each concrete leaf EdgeComponent of the ``components`` package.

        Every module of the package is imported. A class qualifies when it is
        an ``EdgeComponent`` subclass (not the base itself), not abstract, has
        no subclasses of its own and defines ``component_name``. Each qualifying
        class is instantiated once, with this context as sole argument, even if
        several modules import it.

        :return: mapping ``component_name -> instance``.
        """
        package_name = 'components'
        components: Dict[str, EdgeComponent] = {}
        seen = set()
        package = importlib.import_module(package_name)
        for _, module_name, _ in pkgutil.iter_modules(package.__path__, package_name + "."):
            module = importlib.import_module(module_name)
            for _, cls in inspect.getmembers(module, inspect.isclass):
                if cls in seen or cls is EdgeComponent or not issubclass(cls, EdgeComponent):
                    continue
                seen.add(cls)
                if not inspect.isabstract(cls) and not cls.__subclasses__() and hasattr(cls, 'component_name'):
                    components[cls.component_name] = cls(self)
        return components

    def _init_internal_components(self):
        """Create, configure and start the built-in Database and Scheduler.

        They are started unconditionally, before any discovered component is
        configured or started, and registered under their ``component_name``.
        """
        database = Database(self)
        scheduler = Scheduler(self)
        database.configure(settings)
        scheduler.configure(settings)
        database.start()
        scheduler.start()
        self.components[database.__class__.component_name] = database
        self.components[scheduler.__class__.component_name] = scheduler

    def _init_components(self):
        """Configure every discovered component and start those with auto-start.

        Database/Scheduler are skipped (handled by _init_internal_components).
        A component whose start() raises is logged and skipped so the others
        still start.
        """
        internal = (Database.component_name, Scheduler.component_name)
        for name, component in self.components.items():
            if name in internal:
                continue
            component.configure(settings)
            if not component.__class__.component_auto_start:
                continue
            try:
                if not component.is_started:
                    component.start()
            except Exception as e:
                # The cause goes into the message itself: passing `e` as a
                # %-format argument with no placeholder breaks log formatting.
                self.logger.error(f"{name}启动失败: {e}")

    def _get_non_abstract_subclasses(self, cls: Type[EdgeComponent]) -> Dict[str, EdgeComponent]:
        """Recursively instantiate the concrete leaf subclasses of *cls*.

        Each leaf (not abstract, no subclasses, defines ``component_name``) is
        constructed with this context, like in _load_components, configured,
        started if ``component_auto_start`` is set, and keyed by
        ``component_name``. Other subclasses are descended into.
        """
        found: Dict[str, EdgeComponent] = {}
        for subclass in set(cls.__subclasses__()):
            is_leaf = (not inspect.isabstract(subclass)
                       and not subclass.__subclasses__()
                       and hasattr(subclass, 'component_name'))
            if not is_leaf:
                found.update(self._get_non_abstract_subclasses(subclass))
                continue
            instance = subclass(self)
            instance.configure(settings)
            if subclass.component_auto_start:
                instance.start()
            found[subclass.component_name] = instance
        return found

    @staticmethod
    def _convert_param(name: str, annotation: Any, value: Any) -> Any:
        """Coerce one service argument for execute().

        The value is converted with ``annotation(value)`` only when the
        annotation is a concrete class. It is passed through unchanged when
        the parameter is unannotated, annotated ``Any``, or annotated with a
        non-class such as ``Optional[int]``, ``List[str]`` or a string
        forward reference.

        :raises ValueError: when the annotated class rejects the value.
        """
        if (annotation is inspect.Parameter.empty or annotation is Any
                or not isinstance(annotation, type)):
            return value
        try:
            return annotation(value)
        except (TypeError, ValueError) as e:
            raise ValueError(f"Invalid value '{value}' for parameter '{name}': {e}") from e

    async def execute(self, func_name, *args, **kwargs):
        """Invoke the ``@service`` method *func_name* of this context.

        Arguments are matched to parameters by name from *kwargs*. Each value is
        converted via _convert_param, and missing parameters fall back to their
        default. Keyword-only parameters are passed by keyword. Entries of
        *kwargs* that match no parameter name are forwarded as extra keyword
        arguments; only methods with ``**kwargs`` accept them. Coroutine
        functions are awaited; plain functions run via ``asyncio.to_thread``.

        :param func_name: snake_case name of the service method.
        :param args: accepted for interface compatibility; ignored.
        :raises RuntimeError: no such attribute, not a service, or not callable.
        :raises ValueError: a required parameter is missing or cannot be converted.
        :return: the service method's return value.
        """
        func = getattr(self, func_name, None)
        if func is None:
            raise RuntimeError(f"Function {func_name} not found in {self.__class__.__name__}")
        # Use func_name, not func.__name__: non-function attributes have no __name__.
        if not getattr(func, '_component_service', False):
            raise RuntimeError(f"Function {func_name} does not have the required annotation.")
        if not callable(func):
            raise RuntimeError(f"Function {func_name} not found in {self.__class__.__name__}")

        params = inspect.signature(func).parameters
        call_args = []
        call_kwargs = {}
        for name, param in params.items():
            if name == 'self' or param.kind == inspect.Parameter.VAR_KEYWORD:
                continue
            if name in kwargs:
                value = self._convert_param(name, param.annotation, kwargs[name])
            elif param.default is inspect.Parameter.empty:
                raise ValueError(f"Missing required parameter '{name}' for method '{func_name}'.")
            else:
                value = param.default
            # Keyword-only parameters cannot be supplied positionally.
            if param.kind == inspect.Parameter.KEYWORD_ONLY:
                call_kwargs[name] = value
            else:
                call_args.append(value)
        # Forward unmatched keys; any declared name (incl. the **kwargs name) is excluded.
        call_kwargs.update({k: v for k, v in kwargs.items() if k not in params})

        if inspect.iscoroutinefunction(func):
            return await func(*call_args, **call_kwargs)
        return await asyncio.to_thread(func, *call_args, **call_kwargs)

    def get_component(self, name: str) -> EdgeComponent:
        """Return the component registered as *name*.

        :raises ValueError: when no component is registered under *name*.
        """
        component = self.components.get(name)
        if component is None:
            raise ValueError(f"Component {name} not found")
        return component

    async def _dispatch_request(self, req: Any) -> Dict[str, Any]:
        """Execute one RPC request envelope and build its response dict.

        This is shared by the MQTT and TCP transports. It does not raise;
        every failure is reported through ``code``/``message``:

        * ``0``: success, with ``result`` holding the return value;
        * ``-1``: invalid/missing ``type``, missing ``method``, or the call raised;
        * ``-2``: the named component is not registered;
        * ``-9``: the decoded JSON is not an object.

        ``type == "task"`` runs execute_task(**params). ``type == "service"``
        calls ``method`` (converted by camel_to_snake) on ``component``, or on
        this context when ``component`` is empty/absent.
        """
        if not isinstance(req, dict):
            return {"requestId": None, "code": -9, "message": "Invalid JSON request"}
        request_id = req.get("requestId")
        request_type = req.get("type")
        if request_type not in ('service', 'task'):
            return {"requestId": request_id, "code": -1,
                    "message": 'invalid reqeust type!', "type": request_type}

        def reply(code: int, message: str, **extra: Any) -> Dict[str, Any]:
            return {"requestId": request_id, "type": request_type,
                    "code": code, "message": message, **extra}

        kwargs = req.get("params")
        if kwargs is None:
            kwargs = {}

        if request_type == 'task':
            try:
                result = await self.execute_task(**kwargs)
            except Exception as e:
                return reply(-1, str(e))
            return reply(0, "success", result=result)

        method = req.get('method')
        if method is None:
            return reply(-1, "没有method参数")
        component = req.get('component')
        if component is None or component == "":
            executor = self
        else:
            try:
                executor = self.get_component(component)
            except (ValueError, TypeError):  # unknown (or unhashable) name
                return reply(-2, f"Executor {component} not found")
        try:
            result = await executor.execute(camel_to_snake(method), **kwargs)
        except Exception as e:
            return reply(-1, str(e))
        return reply(0, "success", result=result)

    @staticmethod
    def _encode_response(response: Dict[str, Any]) -> str:
        """Serialise *response* to JSON.

        If the service result is not JSON-serialisable, this falls back to a
        ``code -1`` error reply so the caller still receives an answer.
        """
        try:
            return json.dumps(response)
        except (TypeError, ValueError) as e:
            return json.dumps({"requestId": response.get("requestId"), "type": response.get("type"),
                               "code": -1, "message": f"result is not JSON serializable: {e}"})

    async def handle_message(self, client, userdata, message):
        """Answer one MQTT message.

        Messages on ``MQTT_RPC_TOPIC_PREFIX/<rpc id>`` are dispatched and the
        response is published to ``MQTT_RPC_TOPIC_RESP_PREFIX/<rpc id>``.
        Undecodable payloads get ``{"code": -9, "message": "Invalid JSON"}``.
        Messages on other topics are only logged.
        """
        self.logger.info(f"Receive: topic={message.topic} content={message.payload}")
        topic = message.topic
        if not topic.startswith(MQTT_RPC_TOPIC_PREFIX):
            return
        # e.g. 'v1/devices/me/rpc/request/0' -> rpc id '0'
        rpc_id = topic.rsplit('/', 1)[-1]
        response_topic = f"{MQTT_RPC_TOPIC_RESP_PREFIX}/{rpc_id}"
        try:
            payload = json.loads(message.payload)
        except ValueError:  # JSONDecodeError and UnicodeDecodeError
            self.logger.error("Received invalid JSON")
            response = {"requestId": None, "code": -9, "message": "Invalid JSON"}
        else:
            response = await self._dispatch_request(payload)
        self.client.publish(response_topic, self._encode_response(response))

    def on_connect(self, client, userdata, flags, rc):
        """paho callback: subscribe to RPC requests once the broker accepts us.

        Subscribing here (not once in start()) re-subscribes after every reconnect.
        """
        if rc != 0:
            self.logger.error("Failed to connect, return code %d\n", rc)
            return
        self.logger.info("Connected to MQTT Broker")
        client.subscribe(MQTT_RPC_TOPIC)

    def on_disconnect(self, client, userdata, rc):
        """paho callback: log the disconnect result code (0 means disconnect() was called)."""
        self.logger.info("MQTT disconnect, return code %d\n", rc)

    def on_message(self, client, userdata, message):
        """paho callback: run handle_message to completion on a fresh event loop.

        NOTE(review): asyncio.run creates a new loop per message, blocks the
        calling (paho) thread until the request finishes, and would raise if
        invoked from a thread that is already running an event loop.
        """
        try:
            asyncio.run(self.handle_message(client, userdata, message))
        except Exception as e:
            # Callback boundary: log instead of letting paho's network loop die.
            self.logger.error(f"MQTT message handling failed: {e}")

    def _stop_components(self):
        """Stop every component that is currently started.

        Discovered components go first, then the Scheduler and finally the
        Database (reverse of how the internal ones were started). Failures are
        logged and do not prevent the remaining components from stopping.
        """
        internal = [Database.component_name, Scheduler.component_name]
        order = [name for name in self.components if name not in internal]
        order += [name for name in reversed(internal) if name in self.components]
        for name in order:
            component = self.components[name]
            try:
                if component.is_started:
                    component.stop()
            except Exception as e:
                self.logger.error(f"{name}停止失败: {e}")

    def start(self):
        """Connect MQTT, start the TCP (optional) and REST servers, then block.

        The servers run on a new event loop stored in ``self.loop``. This call
        blocks in ``run_forever()`` until that loop is stopped.
        """
        mqtt_server = settings.get("mqtt.server", "127.0.0.1")
        mqtt_port = settings.get("mqtt.port", 1883)
        mqtt_username = settings.get("mqtt.username", "")
        mqtt_password = settings.get("mqtt.password", "")
        self.logger.info(f"mqtt: {mqtt_server}:{mqtt_port}:{mqtt_username}")
        self.client.username_pw_set(mqtt_username, mqtt_password)
        self.client.connect(mqtt_server, mqtt_port, 60)
        # NOTE(review): loop_start() is disabled, so nothing here drives the
        # paho network loop; confirm the caller starts it, otherwise
        # on_connect/on_message never fire.
        # self.client.loop_start()

        self.loop = asyncio.new_event_loop()
        asyncio.set_event_loop(self.loop)
        # Hold strong references: the event loop only keeps weak refs to tasks.
        self._server_tasks = []
        if settings.get("tcp.enable", False):
            self._server_tasks.append(self.loop.create_task(self.start_tcp_server()))
        self._server_tasks.append(self.loop.create_task(self.server.serve()))
        self.logger.info('edge context start!')
        self.loop.run_forever()

    def _run_on_loop(self, coro):
        """Run *coro* to completion on ``self.loop`` (the loop owning the servers).

        * Loop idle (e.g. run_forever() exited): ``run_until_complete``.
        * Loop running in another thread: ``run_coroutine_threadsafe`` and wait.
        * No usable loop, or called from the loop's own thread: ``asyncio.run``.
          In the latter case this raises RuntimeError, as before; blocking
          there would deadlock.
        """
        loop = self.loop
        if loop is not None and not loop.is_closed():
            if not loop.is_running():
                return loop.run_until_complete(coro)
            try:
                on_loop_thread = asyncio.get_running_loop() is loop
            except RuntimeError:  # no loop running in this thread
                on_loop_thread = False
            if not on_loop_thread:
                return asyncio.run_coroutine_threadsafe(coro, loop).result()
        return asyncio.run(coro)

    def stop(self):
        """Stop components, MQTT, the REST and TCP servers, then the event loop.

        A loop still running in another thread is stopped thread-safely and
        left for that thread; an idle loop is closed here.
        """
        self._stop_components()
        self.client.loop_stop()
        self.client.disconnect()
        self._run_on_loop(self.stop_rest_server())
        if settings.get("tcp.enable", False):
            self._run_on_loop(self.stop_tcp_server())
        loop = self.loop
        if loop is not None and not loop.is_closed():
            if loop.is_running():
                loop.call_soon_threadsafe(loop.stop)
            else:
                loop.close()
        self.logger.info("Edge context stopped!")

    async def handle_tcp(self, reader, writer):
        """Serve one TCP client: JSON requests in, JSON responses out.

        Each ``reader.read(4096)`` chunk is parsed as exactly one request and
        answered via _dispatch_request. Undecodable data gets a ``code -9``
        reply and closes the connection.

        NOTE(review): there is no message framing; a request over 4096 bytes,
        split across reads, or two requests in one chunk is treated as invalid.
        """
        addr = writer.get_extra_info('peername')
        self.logger.info(f"Connected by {addr}")
        self.tcp_clients.add(writer)
        try:
            while True:
                data = await reader.read(4096)
                if not data:
                    break
                try:
                    req = json.loads(data.decode('utf-8'))
                except ValueError:  # JSONDecodeError and UnicodeDecodeError
                    self.logger.error("Received invalid JSON")
                    error_response = json.dumps({"requestId": None, "code": -9, "message": "Invalid JSON"})
                    writer.write(error_response.encode('utf-8'))
                    await writer.drain()
                    break
                self.logger.info(f"Received JSON request: {req}")
                response = await self._dispatch_request(req)
                writer.write(self._encode_response(response).encode('utf-8'))
                await writer.drain()
        except asyncio.CancelledError:
            self.logger.error(f"Connection with {addr} cancelled.")
            raise  # cancellation must propagate to the task's owner
        except ConnectionError as e:
            self.logger.error(f"Connection with {addr} lost: {e}")
        finally:
            # discard: stop_tcp_server may already have cleared the set.
            self.tcp_clients.discard(writer)
            self.logger.info(f"Closing connection with {addr}")
            writer.close()
            try:
                await writer.wait_closed()
            except ConnectionError:
                pass

    async def start_tcp_server(self):
        """Start the JSON TCP server on ``tcp.host``:``tcp.port``; serve until closed."""
        tcp_host = settings.get("tcp.host", "0.0.0.0")
        tcp_port = settings.get("tcp.port", 13000)
        self.tcp_server = await asyncio.start_server(self.handle_tcp, tcp_host, tcp_port)
        addr = self.tcp_server.sockets[0].getsockname()
        self.logger.info(f"Tcp Serving on {addr}")
        async with self.tcp_server:
            await self.tcp_server.serve_forever()

    async def broadcast(self, message):
        """Send *message* (``str.encode()``d) to every connected TCP client.

        Iterates over a snapshot because handle_tcp removes clients while this
        awaits ``drain()``. A client whose connection fails is logged and skipped.
        """
        payload = message.encode()
        for client in list(self.tcp_clients):
            try:
                client.write(payload)
                await client.drain()
            except ConnectionError as e:
                self.logger.error(f"broadcast to {client.get_extra_info('peername')} failed: {e}")

    async def stop_tcp_server(self):
        """Stop accepting connections, close every client, then wait for the server.

        Clients are closed *before* awaiting ``wait_closed()``: since Python
        3.12 it waits for all active connections, so the reverse order hangs.
        """
        self.logger.info("Stopping tcp server...")
        if self.tcp_server is not None:
            self.tcp_server.close()  # stop accepting new connections
        clients = list(self.tcp_clients)
        self.tcp_clients.clear()
        for client in clients:
            client.close()
        for client in clients:
            try:
                await client.wait_closed()
            except ConnectionError:
                pass
        if self.tcp_server is not None:
            await self.tcp_server.wait_closed()
        self.logger.info("Tcp Server stopped.")

    async def stop_rest_server(self):
        """Ask uvicorn to exit (``should_exit``) and allow a fixed 5 s grace period.

        NOTE(review): nothing verifies the server actually finished shutting down.
        """
        self.logger.info("Stopping fastapi server...")
        self.server.should_exit = True
        await asyncio.sleep(5)
        self.logger.info("Fastapi Server stopped.")

    @service()
    async def execute_task(self, **kwargs):
        """Run an EdgeTask built from *kwargs*.

        ``task.steps`` run in order. A step that is a list is a group whose
        EdgeTaskSteps run concurrently (asyncio.gather); any other step runs
        alone. Each step yields ``{"status": "success", "result": ...}`` or
        ``{"status": "error", "error": "<message>"}``, so a failing step does
        not abort the task.

        :return: one entry per step; parallel groups yield a list of entries.
        """
        task = EdgeTask(**kwargs)

        async def run_step(step: EdgeTaskStep):
            try:
                if step.component is None or step.component == "":
                    executor = self
                else:
                    executor = self.get_component(step.component)
                res = await executor.execute(step.method, **(step.params or {}))
                return {"status": "success", "result": res}
            except Exception as e:
                return {"status": "error", "error": str(e)}

        self.logger.info(f"执行任务{task.name}({task.id})开始")
        results = []
        for step in task.steps or []:
            if isinstance(step, list):
                results.append(await asyncio.gather(*(run_step(sub_step) for sub_step in step)))
            else:
                results.append(await run_step(step))
        self.logger.info(f"执行任务{task.name}({task.id})结束")
        return results

    @service()
    def list_components(self):
        """Return the names of all registered edge components."""
        return list(self.components.keys())

    @service()
    def list_component_services(self, comp=None):
        """Return the names of the ``@service`` methods of component *comp*.

        :param comp: component name; ``None`` lists this context's own services.
        :raises ValueError: when *comp* names no registered component.
        """
        target = self if comp is None else self.get_component(comp)
        return [name
                for name, method in inspect.getmembers(target.__class__, predicate=inspect.isfunction)
                if getattr(method, '_component_service', False)]