mirror of
http://git.xinwangdao.com/cnnc-embedded-parts-detect/detect-gui.git
synced 2025-06-24 13:14:11 +08:00
473 lines
21 KiB
Python
473 lines
21 KiB
Python
import asyncio
|
||
import importlib
|
||
import inspect
|
||
import json
|
||
import os
|
||
import pkgutil
|
||
import re
|
||
from typing import Type, Dict, Optional, Any, List
|
||
|
||
from fastapi import FastAPI, Request
|
||
from fastapi.staticfiles import StaticFiles
|
||
from httpx import delete
|
||
from starlette.responses import JSONResponse
|
||
from uvicorn import Server, Config
|
||
|
||
from core.edge_component import EdgeComponent, service
|
||
from core.config import settings
|
||
from core.edge_internal import Database, Scheduler
|
||
from core.edge_logger import LoggingMixin
|
||
|
||
import paho.mqtt.client as mqtt
|
||
|
||
from core.edge_routes import setup_routes
|
||
from core.edge_task import EdgeTask, EdgeTaskStep
|
||
from core.edge_util import camel_to_snake
|
||
|
||
# MQTT topic names used by the edge context.
MQTT_TELEMETRY_TOPIC = "v1/devices/me/telemetry"
MQTT_RPC_TOPIC_PREFIX = "v1/devices/me/rpc/request"
MQTT_RPC_TOPIC_RESP_PREFIX = "v1/devices/me/rpc/response"
# Subscription filter for RPC requests: "+" is the MQTT single-level
# wildcard and stands for the trailing request-id segment of the topic.
MQTT_RPC_TOPIC = MQTT_RPC_TOPIC_PREFIX + "/+"
|
||
|
||
|
||
class EdgeContext(LoggingMixin):
|
||
|
||
def __init__(self):
    """Create the MQTT client, load the edge components and set up the REST app.

    The MQTT client is only created and wired to this object's callbacks
    here; connecting to the broker happens in ``start``. Components are
    discovered, configured and (where flagged for auto start) started
    during construction. The event loop and the TCP server handle start
    out as ``None`` and are assigned later by ``start`` and
    ``start_tcp_server``.
    """
    super().__init__()

    # MQTT client and its callbacks
    self.client = mqtt.Client()
    for hook in ("on_connect", "on_disconnect", "on_message"):
        setattr(self.client, hook, getattr(self, hook))

    # Discover the edge components and instantiate them
    self.components: Dict[str, EdgeComponent] = self._load_components()
    self._init_internal_components()
    self._init_components()
    self.logger.info(f"当前边缘组件:{','.join(self.components.keys())}")

    # Event loop (created in start())
    self.loop = None

    # TCP server handle and the stream writers of connected clients
    self.tcp_server = None
    self.tcp_clients = set()

    # REST application
    http_port = settings.get("rest.port", 8000)
    self.app = FastAPI()

    # Serve the static directory under /local, creating it when missing
    static_root = settings.get('rest.static', 'local')
    os.makedirs(static_root, exist_ok=True)
    self.app.mount("/local", StaticFiles(directory=static_root), name="local")

    self.config = Config(app=self.app, host="0.0.0.0", port=http_port)
    self.server = Server(self.config)
    setup_routes(self, self.app)
|
||
|
||
def _load_components(self) -> Dict[str, EdgeComponent]:
    """Discover and instantiate the edge components of the ``components`` package.

    Every direct submodule of ``components`` is imported, then each class
    found in those modules is instantiated (with this context as its only
    argument) when it is a subclass of ``EdgeComponent``, is not abstract,
    has no subclasses of its own and defines ``component_name``.

    :return: mapping of ``component_name`` to the created instance
    """
    package_name = 'components'
    package = importlib.import_module(package_name)

    # Import every module before inspecting any class: the leaf check below
    # relies on ``__subclasses__()``, which only knows about subclasses whose
    # modules have already been imported. Doing this in one pass would make
    # the result depend on module iteration order.
    modules = [
        importlib.import_module(module_name)
        for _, module_name, _ in pkgutil.iter_modules(package.__path__, package_name + ".")
    ]

    components: Dict[str, EdgeComponent] = {}
    # A class imported into several modules shows up once per module;
    # remember what was already built so it is only instantiated once.
    instantiated = set()
    for module in modules:
        for _, cls in inspect.getmembers(module, inspect.isclass):
            if cls in instantiated:
                continue
            if not issubclass(cls, EdgeComponent) or cls is EdgeComponent:
                continue
            # Only concrete leaf classes that declare a component name
            if inspect.isabstract(cls) or cls.__subclasses__() or not hasattr(cls, 'component_name'):
                continue
            instantiated.add(cls)
            components[cls.component_name] = cls(self)

    return components
|
||
|
||
def _init_internal_components(self):
    """Create, configure, start and register the built-in components.

    ``Database`` and ``Scheduler`` are both constructed first, then both
    configured from ``settings``, then both started (database before
    scheduler at every stage), and finally stored in ``self.components``
    under their class-level ``component_name``.
    """
    internals = (Database(self), Scheduler(self))
    for component in internals:
        component.configure(settings)
    for component in internals:
        component.start()
    for component in internals:
        self.components[component.__class__.component_name] = component
|
||
|
||
def _init_components(self):
    """Configure every discovered component and start the auto-start ones.

    The built-in ``Database`` and ``Scheduler`` entries are skipped here.
    A failure while starting one component is logged and does not prevent
    the remaining components from being processed; a failure in
    ``configure`` is not caught.
    """
    internal_names = (Database.component_name, Scheduler.component_name)
    for component_key, component in self.components.items():
        if component_key in internal_names:
            continue
        component.configure(settings)
        if not component.__class__.component_auto_start:
            continue
        try:
            if not component.is_started:
                component.start()
        except Exception as e:
            # The exception is interpolated into the message itself: handing
            # it over as an extra positional argument makes it a %-format
            # argument without a placeholder, which breaks the log call.
            self.logger.error(f"{component_key}启动失败: {e}")
|
||
|
||
def _get_non_abstract_subclasses(self, cls: Type[EdgeComponent]) -> Dict[str, EdgeComponent]:
    """Recursively instantiate the concrete leaf subclasses of ``cls``.

    A subclass is instantiated when it is not abstract, has no subclasses
    of its own and defines ``component_name``; any other subclass is
    searched recursively. Each created instance is configured from
    ``settings`` and started when its class sets ``component_auto_start``.

    :param cls: class whose subclass tree is searched
    :return: mapping of ``component_name`` to the created instance
    """
    found: Dict[str, EdgeComponent] = {}

    for subclass in set(cls.__subclasses__()):
        is_leaf_component = (
            not inspect.isabstract(subclass)
            and not subclass.__subclasses__()
            and hasattr(subclass, 'component_name')
        )
        if not is_leaf_component:
            found.update(self._get_non_abstract_subclasses(subclass))
            continue

        # Components are constructed with the owning context, exactly as
        # _load_components and _init_internal_components do.
        instance = subclass(self)
        instance.configure(settings)
        if subclass.component_auto_start:
            instance.start()
        found[subclass.component_name] = instance

    return found
|
||
|
||
async def execute(self, func_name, *args, **kwargs):
    """Invoke a service method of this object by name.

    The target must be a callable attribute carrying a truthy
    ``_component_service`` marker. Keyword arguments that match a declared
    parameter are converted with the parameter's annotation when that
    annotation is a class (``int``, ``str``, ...); unannotated parameters,
    ``Any`` and non-class annotations (e.g. ``Optional[int]``) receive the
    value unchanged. Keyword arguments that match no declared parameter are
    forwarded as they are. Coroutine functions are awaited directly, plain
    functions run in a worker thread via ``asyncio.to_thread``.

    :param func_name: name of the attribute to call
    :param args: accepted for interface compatibility; not forwarded
    :param kwargs: arguments for the target, by parameter name
    :return: whatever the target returns
    :raises RuntimeError: the attribute is missing, not callable, or not
        marked as a service
    :raises ValueError: a required parameter is missing or a value cannot
        be converted to the annotated type
    """
    func = getattr(self, func_name, None)
    if func is None:
        raise RuntimeError(f"Function {func_name} not found in {self.__class__.__name__}")

    # Error messages use func_name rather than func.__name__: the attribute
    # may be an arbitrary object without a __name__.
    if not getattr(func, '_component_service', False):
        raise RuntimeError(f"Function {func_name} does not have the required annotation.")
    if not callable(func):
        raise RuntimeError(f"Function {func_name} not found in {self.__class__.__name__}")

    method_params = inspect.signature(func).parameters
    variadic_kinds = (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)
    call_args = []
    call_kwargs = {}

    for param_name, param_info in method_params.items():
        # *args / **kwargs of the target are not bound by name
        if param_name == 'self' or param_info.kind in variadic_kinds:
            continue
        keyword_only = param_info.kind is inspect.Parameter.KEYWORD_ONLY

        if param_name in kwargs:
            raw_value = kwargs[param_name]
            param_type = param_info.annotation
            if param_type is inspect.Parameter.empty or param_type is Any:
                # No usable type information: pass the value through
                value = raw_value
            elif isinstance(param_type, type):
                try:
                    value = param_type(raw_value)
                except (TypeError, ValueError) as e:
                    raise ValueError(
                        f"Invalid value '{raw_value}' for parameter '{param_name}': {e}") from e
            else:
                # typing constructs and string annotations cannot be used
                # as converters
                value = raw_value
        elif param_info.default is inspect.Parameter.empty:
            raise ValueError(f"Missing required parameter '{param_name}' for method '{func_name}'.")
        elif keyword_only:
            # Let the target apply its own default
            continue
        else:
            value = param_info.default

        if keyword_only:
            call_kwargs[param_name] = value
        else:
            call_args.append(value)

    # Everything that did not match a declared parameter is forwarded as-is
    call_kwargs.update(
        {key: value for key, value in kwargs.items() if key not in method_params})

    if inspect.iscoroutinefunction(func):
        return await func(*call_args, **call_kwargs)
    return await asyncio.to_thread(func, *call_args, **call_kwargs)
|
||
|
||
def get_component(self, name: str) -> EdgeComponent:
    """Look up a registered component by name.

    :param name: key in ``self.components``
    :return: the component stored under ``name``
    :raises ValueError: nothing (or a falsy value) is stored under ``name``
    """
    found = self.components.get(name)
    if not found:
        raise ValueError(f"Component {name} not found")
    return found
|
||
|
||
async def handle_message(self, client, userdata, message):
    """Handle one MQTT message and answer RPC requests.

    Messages whose topic does not start with ``MQTT_RPC_TOPIC_PREFIX`` are
    only logged. For RPC requests the last topic segment is the RPC id; the
    JSON response is published on ``MQTT_RPC_TOPIC_RESP_PREFIX/<id>``.
    A response is always published: undecodable payloads yield code ``-9``,
    every other failure is reported through the response ``code`` and
    ``message`` fields instead of raising.

    :param client: MQTT client that delivered the message (unused)
    :param userdata: user data registered with the client (unused)
    :param message: received message; ``topic`` and ``payload`` are read
    """
    self.logger.info(f"Receive: topic={message.topic} content={message.payload}")
    # e.g. 'v1/devices/me/rpc/request/0'
    topic = message.topic
    if not topic.startswith(MQTT_RPC_TOPIC_PREFIX):
        return

    rpc_id = topic.rsplit('/', 1)[-1]
    response_topic = f"{MQTT_RPC_TOPIC_RESP_PREFIX}/{rpc_id}"

    try:
        payload = json.loads(message.payload)
    except (TypeError, ValueError):
        # JSONDecodeError and UnicodeDecodeError are both ValueError
        self.logger.error("Received invalid JSON")
        response = {"requestId": None, "code": -9, "message": "Invalid JSON"}
    else:
        response = await self._build_rpc_response(payload)

    try:
        body = json.dumps(response)
    except (TypeError, ValueError) as e:
        # The service result is not JSON serializable: report that instead
        body = json.dumps({"requestId": response.get("requestId"), "type": response.get("type"),
                           "code": -1, "message": str(e)})
    self.client.publish(response_topic, body)


async def _build_rpc_response(self, payload):
    """Dispatch a decoded RPC request and return its response dict.

    Never raises for request-level problems; they are encoded as
    ``code`` -1 (bad request / execution error) or -2 (unknown component).

    NOTE(review): handle_tcp carries the same dispatch logic inline;
    consider routing both transports through one helper.

    :param payload: decoded JSON request
    :return: response dict with ``requestId``, ``code`` and ``message``
    """
    if not isinstance(payload, dict):
        payload = {}
    request_id = payload.get("requestId")
    request_type = payload.get("type")

    if request_type not in ('service', 'task'):
        return {"requestId": request_id, "code": -1,
                "message": 'invalid reqeust type!', "type": request_type}

    def failure(code, text):
        return {"requestId": request_id, "type": request_type, "code": code, "message": text}

    params = payload.get("params")
    if params is None:
        params = {}

    try:
        if request_type == 'task':
            result = await self.execute_task(**params)
        else:
            method = payload.get('method')
            # Validate before converting the name: there is nothing to
            # convert when the request carries no method.
            if method is None:
                return failure(-1, "没有method参数!")

            component = payload.get('component')
            if component == "" or component is None:
                executor = self
            else:
                try:
                    executor = self.get_component(component)
                except ValueError:
                    return failure(-2, f"Executor {component} not found")

            result = await executor.execute(camel_to_snake(method), **params)
    except Exception as e:
        return failure(-1, str(e))

    return {"requestId": request_id, "type": request_type, "code": 0,
            "message": "success", "result": result}
|
||
|
||
def on_connect(self, client, userdata, flags, rc):
    """MQTT connect callback.

    On return code 0 the RPC request filter ``MQTT_RPC_TOPIC`` is
    subscribed; any other return code is only logged.
    """
    if rc != 0:
        self.logger.info("Failed to connect, return code %d\n", rc)
        return
    self.logger.info("Connected to MQTT Broker")
    client.subscribe(MQTT_RPC_TOPIC)
|
||
|
||
def on_disconnect(self, client, userdata, rc):
    """MQTT disconnect callback: write the return code to the log."""
    log = self.logger
    log.info("MQTT disconnect, return code %d\n", rc)
|
||
|
||
def on_message(self, client, userdata, message):
    """MQTT message callback: run ``handle_message`` to completion.

    The coroutine is driven by ``asyncio.run``, i.e. on a fresh event loop
    in the calling thread. Failures are logged rather than propagated, so
    that one bad message cannot raise out of the MQTT client's callback
    dispatch.
    """
    try:
        asyncio.run(self.handle_message(client, userdata, message))
    except Exception as e:
        self.logger.error(f"MQTT message handling failed: {e}")
|
||
|
||
def _stop_components(self):
    """Stop every started component that is flagged for auto start.

    A failure while stopping one component is logged and does not prevent
    the remaining components from being stopped.
    """
    for component_key, component in self.components.items():
        # NOTE(review): components are re-configured right before being
        # stopped and only auto-start components are stopped; both mirror
        # _init_components and may be unintended - confirm.
        component.configure(settings)
        if not component.__class__.component_auto_start:
            continue
        try:
            if component.is_started:
                component.stop()
        except Exception as e:
            # The exception is interpolated into the message itself: handing
            # it over as an extra positional argument makes it a %-format
            # argument without a placeholder, which breaks the log call.
            self.logger.error(f"{component_key}停止失败: {e}")
|
||
|
||
def start(self):
    """Connect to the MQTT broker and run the servers on a new event loop.

    Reads the ``mqtt.*`` settings, connects the client, creates a fresh
    event loop, schedules the REST server (and the TCP server when
    ``tcp.enable`` is set) on it and then blocks in ``run_forever``.
    """
    # mqtt
    mqtt_server = settings.get("mqtt.server", "127.0.0.1")
    mqtt_port = settings.get("mqtt.port", 1883)
    mqtt_username = settings.get("mqtt.username", "")
    mqtt_password = settings.get("mqtt.password", "")
    self.logger.info(f"mqtt: {mqtt_server}:{mqtt_port}:{mqtt_username}")
    self.client.username_pw_set(mqtt_username, mqtt_password)
    self.client.connect(mqtt_server, mqtt_port, 60)
    # NOTE(review): the client's network loop is not started here (the
    # loop_start() call in this method is disabled), so the MQTT callbacks
    # only fire if something else drives the client - confirm intent.

    # The TCP server and the FastAPI server run on a dedicated event loop
    self.loop = asyncio.new_event_loop()
    asyncio.set_event_loop(self.loop)

    pending = []
    # tcp server (optional)
    if settings.get("tcp.enable", False):
        pending.append(self.start_tcp_server())
    # FastAPI server
    pending.append(self.server.serve())
    # Keep references to the tasks for as long as the loop runs
    tasks = [self.loop.create_task(coro) for coro in pending]

    self.logger.info('edge context start!')
    # Blocks until the loop is stopped
    self.loop.run_forever()
|
||
|
||
def stop(self):
    """Shut down components, the MQTT client, both servers and the event loop.

    Intended to be called from a thread other than the one blocked in
    ``start``: ``asyncio.run`` (used for the REST shutdown) raises
    ``RuntimeError`` when invoked from a thread that is already running an
    event loop.

    The loop is only closed here when it is not running; a running loop is
    asked to stop and is left for its own thread to finish.
    """
    self._stop_components()

    self.client.loop_stop()
    self.client.disconnect()

    # rest: stop_rest_server only flips the server's exit flag and waits,
    # so it does not matter which loop runs it
    asyncio.run(self.stop_rest_server())

    loop = self.loop
    loop_running = loop is not None and loop.is_running()

    # tcp server: its sockets and streams belong to self.loop, so the
    # shutdown coroutine has to run there while that loop is alive
    if settings.get("tcp.enable", False) and self.tcp_server is not None:
        try:
            if loop_running:
                future = asyncio.run_coroutine_threadsafe(self.stop_tcp_server(), loop)
                # Bounded wait so a stuck connection cannot block shutdown
                future.result(timeout=10)
            else:
                asyncio.run(self.stop_tcp_server())
        except Exception as e:
            self.logger.error(f"Failed to stop tcp server: {e!r}")

    if loop is not None:
        if loop.is_running():
            # loop.stop() is not thread-safe; let the loop's own thread run it
            loop.call_soon_threadsafe(loop.stop)
        else:
            # Only a loop that is not running may be closed
            loop.close()

    self.logger.info("Edge context stopped!")
|
||
|
||
async def handle_tcp(self, reader, writer):
|
||
addr = writer.get_extra_info('peername')
|
||
self.logger.info(f"Connected by {addr}")
|
||
self.tcp_clients.add(writer) # 添加客户端连接
|
||
try:
|
||
while True:
|
||
data = await reader.read(8192)
|
||
if not data:
|
||
break
|
||
|
||
requestId = None
|
||
try:
|
||
req = json.loads(data.decode('utf-8'))
|
||
self.logger.info(f"Received JSON request: {req}")
|
||
requestId = req["requestId"]
|
||
response = {
|
||
"requestId": requestId,
|
||
"code": 0,
|
||
"message": "success",
|
||
}
|
||
reqeust_type = req["type"]
|
||
if not reqeust_type in ['service', 'task']:
|
||
response['type'] = reqeust_type
|
||
response['code'] = -1
|
||
response['message'] = 'invalid reqeust type!'
|
||
else:
|
||
if reqeust_type == 'service':
|
||
method = req.get('method')
|
||
component = req.get('component')
|
||
method = camel_to_snake(method)
|
||
if req.get("method") is None:
|
||
response = {"requestId": requestId, "type": reqeust_type, "code": -1,
|
||
"message": "没有method参数!"}
|
||
else:
|
||
if component == "" or component is None:
|
||
executor = self
|
||
else:
|
||
executor = self.get_component(component)
|
||
if executor:
|
||
try:
|
||
kwargs = req.get("params")
|
||
if kwargs is None:
|
||
kwargs = {}
|
||
result = await executor.execute(method, **kwargs)
|
||
response = {"requestId": requestId, "type": reqeust_type, "code": 0,
|
||
"message": "success", "result": result}
|
||
except Exception as e:
|
||
response = {"requestId": requestId, "type": reqeust_type, "code": -1,
|
||
"message": str(e)}
|
||
else:
|
||
response = {"requestId": requestId, "type": reqeust_type, "code": -2,
|
||
"message": f"Executor {component} not found"}
|
||
else:
|
||
try:
|
||
kwargs = req.get("params")
|
||
if kwargs is None:
|
||
kwargs = {}
|
||
result = await self.execute_task(**kwargs)
|
||
response = {"requestId": requestId, "type": reqeust_type, "code": 0,
|
||
"message": "success", "result": result}
|
||
except Exception as e:
|
||
response = {"requestId": requestId, "type": reqeust_type, "code": -1, "message": str(e)}
|
||
writer.write(json.dumps(response).encode('utf-8'))
|
||
await writer.drain()
|
||
except json.JSONDecodeError:
|
||
self.logger.error("Received invalid JSON")
|
||
error_response = json.dumps({"requestId": requestId, "code": -9, "message": "Invalid JSON"})
|
||
writer.write(error_response.encode('utf-8'))
|
||
await writer.drain()
|
||
break
|
||
except asyncio.CancelledError:
|
||
self.logger.error(f"Connection with {addr} cancelled.")
|
||
finally:
|
||
# 移除客户端连接
|
||
self.tcp_clients.remove(writer)
|
||
self.logger.info(f"Closing connection with {addr}")
|
||
writer.close()
|
||
await writer.wait_closed()
|
||
|
||
async def start_tcp_server(self):
    """Bind the TCP server on ``tcp.port`` (default 13000) and serve forever.

    Connections are handled by ``handle_tcp``. The coroutine stays inside
    ``serve_forever`` until the server is closed or the task is cancelled.
    """
    listen_port = settings.get("tcp.port", 13000)
    self.tcp_server = await asyncio.start_server(self.handle_tcp, '0.0.0.0', listen_port)

    first_socket = self.tcp_server.sockets[0]
    addr = first_socket.getsockname()
    self.logger.info(f"Tcp Serving on {addr}")

    # The context manager closes the server once serve_forever returns
    async with self.tcp_server:
        await self.tcp_server.serve_forever()
|
||
|
||
async def broadcast(self, message):
    """Send ``message`` to every connected TCP client.

    The set of clients is snapshotted first, because clients may connect or
    disconnect while this coroutine is suspended in ``drain``. A client
    whose connection fails is logged and skipped so the remaining clients
    still receive the message.

    :param message: text to send; encoded with ``str.encode()`` defaults
    """
    clients = list(self.tcp_clients)
    if not clients:
        return

    payload = message.encode()
    for client in clients:
        try:
            client.write(payload)
            await client.drain()
        except OSError as e:
            # ConnectionError (reset, broken pipe, ...) is an OSError
            self.logger.error(f"Broadcast to tcp client failed: {e}")
|
||
|
||
async def stop_tcp_server(self):
    """Stop accepting TCP connections and close every connected client.

    Client connections are closed before the server is awaited: on
    Python 3.12+ ``Server.wait_closed()`` does not return while
    connections are still active. Does nothing server-side when the TCP
    server was never started.
    """
    self.logger.info("Stopping tcp server...")
    server = self.tcp_server
    if server is not None:
        server.close()  # stop accepting new connections

    # Snapshot: handlers drop their writer from the set while we await
    clients = list(self.tcp_clients)
    for client in clients:
        client.close()
    for client in clients:
        try:
            await client.wait_closed()
        except OSError as e:
            # The connection was already broken; nothing left to close
            self.logger.error(f"Error while closing tcp client: {e}")

    if server is not None:
        await server.wait_closed()
    self.tcp_clients.clear()
    self.logger.info("Tcp Server stopped.")
|
||
|
||
async def stop_rest_server(self):
    """Ask the uvicorn server to exit and wait a fixed grace period.

    Only sets ``should_exit`` on the server and then sleeps; it does not
    check whether the server has actually finished.
    """
    grace_seconds = 5
    log = self.logger

    log.info("Stopping fastapi server...")
    self.server.should_exit = True
    await asyncio.sleep(grace_seconds)
    log.info("Fastapi Server stopped.")
|
||
|
||
@service()
async def execute_task(self, **kwargs):
    """
    Execute a task.

    ``kwargs`` are passed to ``EdgeTask``. The task's steps run in order;
    a step that is itself a list is run as a group of concurrent steps.
    A failing step does not abort the task: its error is captured in the
    step's result entry.

    :return: one entry per step - a ``{"status": ..., ...}`` dict for a
        single step, or a list of such dicts for a concurrent group
    """
    task = EdgeTask(**kwargs)

    async def run_one(step: EdgeTaskStep):
        """Run a single step on its component (or on this context)."""
        try:
            if step.component is None or step.component == "":
                target = self
            else:
                target = self.get_component(step.component)
            outcome = await target.execute(step.method, **step.params)
        except Exception as e:
            return {"status": "error", "error": str(e)}
        return {"status": "success", "result": outcome}

    async def run_group(group: List[EdgeTaskStep]):
        """Run all steps of a group concurrently."""
        return await asyncio.gather(*[run_one(member) for member in group])

    self.logger.info(f"执行任务{task.name}({task.id})开始")
    results = []
    if len(task.steps) > 0:
        for entry in task.steps:
            # A list of steps runs concurrently, a single step on its own
            runner = run_group if isinstance(entry, list) else run_one
            results.append(await runner(entry))

    self.logger.info(f"执行任务{task.name}({task.id})结束")
    return results
|
||
|
||
@service()
def list_components(self):
    """
    List the names of the currently registered edge components.

    :return: list of the keys of ``self.components``
    """
    return [*self.components.keys()]
|
||
|
||
@service()
def list_component_services(self, comp=None):
    """
    List the service functions of an edge component.

    A function counts as a service when it is defined on the component's
    class and carries a truthy ``_component_service`` marker.

    :param comp: component name; ``None`` or ``""`` selects this context
        itself, matching how requests without a component are dispatched
    :return: list of service function names
    :raises ValueError: no component is registered under ``comp``
        (raised by ``get_component``)
    """
    # get_component raises for unknown names, so no further None check is
    # needed after the lookup.
    target = self if comp is None or comp == "" else self.get_component(comp)

    methods = inspect.getmembers(target.__class__, predicate=inspect.isfunction)
    return [name for name, method in methods if getattr(method, '_component_service', False)]
|