ss928_framework/svp/common/sample_common_mau.c

294 lines
14 KiB
C
Raw Permalink Normal View History

2024-12-16 13:31:45 +08:00
/*
Copyright (c), 2001-2022, Shenshu Tech. Co., Ltd.
*/
#include "sample_common_mau.h"
#include "libapi_common_svp.h"
#include "ss_mpi_sys.h"
static td_bool g_mpi_init = TD_FALSE;
static td_s32 sample_svp_mau_mpi_init(td_void)
{
td_s32 ret;
ret = ss_mpi_sys_exit();
macro_svp_check_exps_return(ret != TD_SUCCESS, ret, ENUM_SVP_ERR_LEVEL_ERROR,
"Error(%#x):ss_mpi_sys_exit failed!\n", ret);
ret = ss_mpi_sys_init();
macro_svp_check_exps_return(ret != TD_SUCCESS, ret, ENUM_SVP_ERR_LEVEL_ERROR,
"Error(%#x):ss_mpi_sys_init failed!\n", ret);
return TD_SUCCESS;
}
td_s32 sample_common_svp_mau_check_mau_mpi_init(td_void)
{
if (g_mpi_init == TD_FALSE) {
if (sample_svp_mau_mpi_init() != TD_SUCCESS) {
macro_svp_trace_err("mau mpi init failed!\n");
return TD_FALSE;
}
g_mpi_init = TD_TRUE;
}
return TD_TRUE;
}
td_s32 sample_common_svp_mau_mpi_exit(td_void)
{
g_mpi_init = TD_FALSE;
if (ss_mpi_sys_exit() != TD_SUCCESS) {
macro_svp_trace_err("Sys exit failed!\n");
return TD_FAILURE;
}
return TD_SUCCESS;
}
static td_s32 sample_svp_mau_check_src_idx_blob_info(const ot_svp_blob *matrix_blob)
{
td_u32 stride;
td_u32 byte_num = (td_u32)sizeof(td_u32);
macro_svp_check_exps_return(matrix_blob->virt_addr == 0, OT_ERR_SVP_MAU_ILLEGAL_PARAM,
ENUM_SVP_ERR_LEVEL_ERROR, "Error, blob->virt_addr can't be zero!\n");
macro_svp_check_exps_return(matrix_blob->phys_addr == 0, OT_ERR_SVP_MAU_ILLEGAL_PARAM,
ENUM_SVP_ERR_LEVEL_ERROR, "Error, blob->phys_addr can't be zero!\n");
/* check phys_addr 4 byte aligned */
macro_svp_check_exps_return(sample_svp_mau_check_align(matrix_blob->phys_addr, byte_num) != 0,
OT_ERR_SVP_MAU_ILLEGAL_PARAM, ENUM_SVP_ERR_LEVEL_ERROR,
"Error(%#x),blob->phys_addr(%llu) should be %u bytes aligned\n",
OT_ERR_SVP_MAU_ILLEGAL_PARAM, matrix_blob->phys_addr, byte_num);
/* check type */
macro_svp_check_exps_return(matrix_blob->type != OT_SVP_BLOB_TYPE_U32, OT_ERR_SVP_MAU_ILLEGAL_PARAM,
ENUM_SVP_ERR_LEVEL_ERROR, "Error(%#x),blob->type(%d) must be %d\n", OT_ERR_SVP_MAU_ILLEGAL_PARAM,
matrix_blob->type, OT_SVP_BLOB_TYPE_U32);
/* check num */
macro_svp_check_exps_return(matrix_blob->num != 1, OT_ERR_SVP_MAU_ILLEGAL_PARAM, ENUM_SVP_ERR_LEVEL_ERROR,
"Error(%#x),blob->num(%u) must be 1\n", OT_ERR_SVP_MAU_ILLEGAL_PARAM, matrix_blob->num);
/* check chn */
macro_svp_check_exps_return(matrix_blob->shape.whc.chn != 1, OT_ERR_SVP_MAU_ILLEGAL_PARAM,
ENUM_SVP_ERR_LEVEL_ERROR, "Error(%#x),blob->shape.whc.chn(%u) must be 1\n",
OT_ERR_SVP_MAU_ILLEGAL_PARAM, matrix_blob->shape.whc.chn);
macro_svp_check_exps_return((matrix_blob->shape.whc.width == 0) ||
(matrix_blob->shape.whc.width > SAMPLE_SVP_MAU_MATRIX_MAX_HEIGHT), OT_ERR_SVP_MAU_ILLEGAL_PARAM,
ENUM_SVP_ERR_LEVEL_ERROR, "Error(%#x), blob->shape.whc.width(%u) must be (0, %u]\n",
OT_ERR_SVP_MAU_ILLEGAL_PARAM, matrix_blob->shape.whc.width, SAMPLE_SVP_MAU_MATRIX_MAX_HEIGHT);
macro_svp_check_exps_return(matrix_blob->shape.whc.height != 1, OT_ERR_SVP_MAU_ILLEGAL_PARAM,
ENUM_SVP_ERR_LEVEL_ERROR, "Error(%#x), blob->shape.whc.height(%u) must be 1\n",
OT_ERR_SVP_MAU_ILLEGAL_PARAM, matrix_blob->shape.whc.height);
/* check stride 16 bytes aligned */
macro_svp_check_exps_return(sample_svp_mau_check_align(matrix_blob->stride, SAMPLE_SVP_MAU_ALIGN_16) != 0,
OT_ERR_SVP_MAU_ILLEGAL_PARAM, ENUM_SVP_ERR_LEVEL_ERROR,
"Error(%#x),blob->stride(%u) should be %u bytes aligned\n",
OT_ERR_SVP_MAU_ILLEGAL_PARAM, matrix_blob->stride, SAMPLE_SVP_MAU_ALIGN_16);
stride = sample_common_svp_align(matrix_blob->shape.whc.width * byte_num, SAMPLE_SVP_MAU_ALIGN_16);
macro_svp_check_exps_return((matrix_blob->stride == 0) ||
(matrix_blob->stride < stride), OT_ERR_SVP_MAU_ILLEGAL_PARAM, ENUM_SVP_ERR_LEVEL_ERROR,
"Error(%#x), blob->stride(%u) can't be 0 and should be equal to or greater than %u\n",
OT_ERR_SVP_MAU_ILLEGAL_PARAM, matrix_blob->stride, stride);
return TD_SUCCESS;
}
static td_s32 sample_svp_mau_set_idx_info_check_param(const ot_svp_mau_src_double_matrix *src_idx,
const ot_svp_mau_ctrl *ctrl, td_u32 left_matrix_height, td_u32 right_matrix_height)
{
td_s32 ret;
macro_svp_check_exps_return(ctrl == TD_NULL, OT_ERR_SVP_MAU_NULL_PTR,
ENUM_SVP_ERR_LEVEL_ERROR, "Error, ctrl is TD_NULL!\n");
macro_svp_check_exps_return(src_idx == TD_NULL, OT_ERR_SVP_MAU_NULL_PTR,
ENUM_SVP_ERR_LEVEL_ERROR, "Error, src_idx is TD_NULL!\n");
macro_svp_check_exps_return((ctrl->has_left_idx != TD_FALSE) && (ctrl->has_left_idx != TD_TRUE),
OT_ERR_SVP_MAU_ILLEGAL_PARAM, ENUM_SVP_ERR_LEVEL_ERROR,
"Error(%#x),ctrl->has_left_idx(%d) must be [%d, %d]\n",
OT_ERR_SVP_MAU_ILLEGAL_PARAM, ctrl->has_left_idx, TD_FALSE, TD_TRUE);
macro_svp_check_exps_return((ctrl->has_right_idx != TD_FALSE) && (ctrl->has_right_idx != TD_TRUE),
OT_ERR_SVP_MAU_ILLEGAL_PARAM, ENUM_SVP_ERR_LEVEL_ERROR,
"Error(%#x),ctrl->has_right_idx(%d) must be [%d, %d]\n",
OT_ERR_SVP_MAU_ILLEGAL_PARAM, ctrl->has_right_idx, TD_FALSE, TD_TRUE);
macro_svp_check_exps_return((left_matrix_height == 0) ||
(left_matrix_height > SAMPLE_SVP_MAU_MATRIX_MAX_HEIGHT), OT_ERR_SVP_MAU_ILLEGAL_PARAM,
ENUM_SVP_ERR_LEVEL_ERROR, "left_matrix_height(%u) must be (0, %u]\n",
left_matrix_height, SAMPLE_SVP_MAU_MATRIX_MAX_HEIGHT);
macro_svp_check_exps_return((right_matrix_height == 0) ||
(right_matrix_height > SAMPLE_SVP_MAU_MATRIX_MAX_HEIGHT), OT_ERR_SVP_MAU_ILLEGAL_PARAM,
ENUM_SVP_ERR_LEVEL_ERROR, "right_matrix_height(%u) must be (0, %u]\n",
right_matrix_height, SAMPLE_SVP_MAU_MATRIX_MAX_HEIGHT);
if (ctrl->has_left_idx == TD_TRUE) {
ret = sample_svp_mau_check_src_idx_blob_info(&src_idx->left_matrix);
macro_svp_check_exps_return(ret != TD_SUCCESS, ret, ENUM_SVP_ERR_LEVEL_ERROR,
"Error(%#x), sample_svp_mau_check_src_idx_left_matrix fail\n", ret);
}
if (ctrl->has_right_idx == TD_TRUE) {
ret = sample_svp_mau_check_src_idx_blob_info(&src_idx->right_matrix);
macro_svp_check_exps_return(ret != TD_SUCCESS, ret, ENUM_SVP_ERR_LEVEL_ERROR,
"Error(%#x), sample_svp_mau_check_src_idx_right_matrix fail\n", ret);
}
return TD_SUCCESS;
}
td_s32 sample_svp_mau_set_idx_info(ot_svp_mau_src_double_matrix *src_idx,
const ot_svp_mau_ctrl *ctrl, const ot_svp_mau_src_double_matrix *src)
{
td_u32 i;
td_u32 *idx = TD_NULL;
td_s32 ret;
td_u32 tmp;
td_u32 left_matrix_height, right_matrix_height;
macro_svp_check_exps_return(src == TD_NULL, OT_ERR_SVP_MAU_NULL_PTR,
ENUM_SVP_ERR_LEVEL_ERROR, "Error, src is TD_NULL!\n");
left_matrix_height = src->left_matrix.shape.whc.height;
right_matrix_height = src->right_matrix.shape.whc.height;
ret = sample_svp_mau_set_idx_info_check_param(src_idx, ctrl, left_matrix_height, right_matrix_height);
macro_svp_check_exps_return(ret != TD_SUCCESS, ret, ENUM_SVP_ERR_LEVEL_ERROR,
"Error(%#x), sample_svp_mau_set_idx_info_check_param fail\n", ret);
/* set left matrix idx */
if (ctrl->has_left_idx == TD_TRUE) {
idx = macro_svp_convert_addr_to_ptr(td_u32, src_idx->left_matrix.virt_addr);
for (i = 0; i < src_idx->left_matrix.shape.whc.width; i++) {
tmp = i * SAMPLE_SVP_MAU_GENERATE_IDX_INTERVAL;
if (tmp < left_matrix_height) {
*(idx + i) = tmp;
} else {
*(idx + i) = left_matrix_height - 1;
}
}
ret = sample_common_svp_flush_cache(src_idx->left_matrix.phys_addr,
macro_svp_convert_addr_to_ptr(td_void, src_idx->left_matrix.virt_addr),
src_idx->left_matrix.shape.whc.height * src_idx->left_matrix.stride);
macro_svp_check_exps_return(ret != TD_SUCCESS, ret,
ENUM_SVP_ERR_LEVEL_ERROR, "Error,flush cache failed!\n");
}
/* set right matrix idx */
if (ctrl->has_right_idx == TD_TRUE) {
idx = macro_svp_convert_addr_to_ptr(td_u32, src_idx->right_matrix.virt_addr);
for (i = 0; i < src_idx->right_matrix.shape.whc.width; i++) {
tmp = i * SAMPLE_SVP_MAU_GENERATE_IDX_INTERVAL;
if (tmp < right_matrix_height) {
*(idx + i) = tmp;
} else {
*(idx + i) = right_matrix_height - 1;
}
}
ret = sample_common_svp_flush_cache(src_idx->right_matrix.phys_addr,
macro_svp_convert_addr_to_ptr(td_void, src_idx->right_matrix.virt_addr),
src_idx->right_matrix.shape.whc.height * src_idx->right_matrix.stride);
macro_svp_check_exps_return(ret != TD_SUCCESS, ret,
ENUM_SVP_ERR_LEVEL_ERROR, "Error,flush cache failed!\n");
}
return TD_SUCCESS;
}
static td_s32 sample_svp_mau_check_src_blob_info(const ot_svp_blob *matrix_blob)
{
td_u32 stride;
td_u32 byte_num = (td_u32)sizeof(td_u32);
macro_svp_check_exps_return(matrix_blob->virt_addr == 0, OT_ERR_SVP_MAU_ILLEGAL_PARAM,
ENUM_SVP_ERR_LEVEL_ERROR, "Error, blob->virt_addr can't be zero!\n");
macro_svp_check_exps_return(matrix_blob->phys_addr == 0, OT_ERR_SVP_MAU_ILLEGAL_PARAM,
ENUM_SVP_ERR_LEVEL_ERROR, "Error, blob->phys_addr can't be zero!\n");
/* check phys_addr 4 byte aligned */
macro_svp_check_exps_return(sample_svp_mau_check_align(matrix_blob->phys_addr, byte_num) != 0,
OT_ERR_SVP_MAU_ILLEGAL_PARAM, ENUM_SVP_ERR_LEVEL_ERROR,
"Error(%#x),blob->phys_addr(%llu) should be %u bytes aligned\n",
OT_ERR_SVP_MAU_ILLEGAL_PARAM, matrix_blob->phys_addr, byte_num);
/* check type */
macro_svp_check_exps_return(matrix_blob->type != OT_SVP_BLOB_TYPE_FP32, OT_ERR_SVP_MAU_ILLEGAL_PARAM,
ENUM_SVP_ERR_LEVEL_ERROR, "Error(%#x),blob->type(%d) must be %d\n", OT_ERR_SVP_MAU_ILLEGAL_PARAM,
matrix_blob->type, OT_SVP_BLOB_TYPE_FP32);
/* check num */
macro_svp_check_exps_return(matrix_blob->num != 1, OT_ERR_SVP_MAU_ILLEGAL_PARAM, ENUM_SVP_ERR_LEVEL_ERROR,
"Error(%#x),blob->num(%u) must be 1\n", OT_ERR_SVP_MAU_ILLEGAL_PARAM, matrix_blob->num);
/* check chn */
macro_svp_check_exps_return(matrix_blob->shape.whc.chn != 1, OT_ERR_SVP_MAU_ILLEGAL_PARAM,
ENUM_SVP_ERR_LEVEL_ERROR, "Error(%#x),blob->shape.whc.chn(%u) must be 1\n",
OT_ERR_SVP_MAU_ILLEGAL_PARAM, matrix_blob->shape.whc.chn);
macro_svp_check_exps_return((matrix_blob->shape.whc.width == 0) ||
(matrix_blob->shape.whc.width > SAMPLE_SVP_MAU_MATRIX_MAX_FP32_WIDTH), OT_ERR_SVP_MAU_ILLEGAL_PARAM,
ENUM_SVP_ERR_LEVEL_ERROR, "Error(%#x),matrix_blob->shape.whc.width(%u) must be (0, %u]\n",
OT_ERR_SVP_MAU_ILLEGAL_PARAM, matrix_blob->shape.whc.width, SAMPLE_SVP_MAU_MATRIX_MAX_FP32_WIDTH);
macro_svp_check_exps_return((matrix_blob->shape.whc.height == 0) ||
(matrix_blob->shape.whc.height > SAMPLE_SVP_MAU_MATRIX_MAX_HEIGHT), OT_ERR_SVP_MAU_ILLEGAL_PARAM,
ENUM_SVP_ERR_LEVEL_ERROR, "Error,matrix_blob->shape.whcheight(%u) must be (0, %u]\n",
matrix_blob->shape.whc.height, SAMPLE_SVP_MAU_MATRIX_MAX_HEIGHT);
/* check stride 16 bytes aligned */
macro_svp_check_exps_return(sample_svp_mau_check_align(matrix_blob->stride, SAMPLE_SVP_MAU_ALIGN_16) != 0,
OT_ERR_SVP_MAU_ILLEGAL_PARAM, ENUM_SVP_ERR_LEVEL_ERROR,
"Error(%#x),blob->stride(%u) should be %u bytes aligned\n",
OT_ERR_SVP_MAU_ILLEGAL_PARAM, matrix_blob->stride, SAMPLE_SVP_MAU_ALIGN_16);
stride = sample_common_svp_align(matrix_blob->shape.whc.width * byte_num, SAMPLE_SVP_MAU_ALIGN_16);
macro_svp_check_exps_return((matrix_blob->stride == 0) ||
(matrix_blob->stride < stride), OT_ERR_SVP_MAU_ILLEGAL_PARAM, ENUM_SVP_ERR_LEVEL_ERROR,
"Error(%#x), blob->stride(%u) can't be 0 and should be equal to or greater than %u\n",
OT_ERR_SVP_MAU_ILLEGAL_PARAM, matrix_blob->stride, stride);
return TD_SUCCESS;
}
td_s32 sample_svp_mau_generate_matrix_data(const ot_svp_mau_src_double_matrix *src)
{
td_s32 ret;
td_u32 i, j;
td_float *matrix_data = TD_NULL;
td_u32 stride, size;
td_float tmp_data;
/* check */
macro_svp_check_exps_return(src == TD_NULL, OT_ERR_SVP_MAU_NULL_PTR,
ENUM_SVP_ERR_LEVEL_ERROR, "Error, src is TD_NULL!\n");
ret = sample_svp_mau_check_src_blob_info(&src->left_matrix);
macro_svp_check_exps_return(ret != TD_SUCCESS, ret, ENUM_SVP_ERR_LEVEL_ERROR,
"Error(%#x), check param of src->left_matrix fail\n", ret);
ret = sample_svp_mau_check_src_blob_info(&src->right_matrix);
macro_svp_check_exps_return(ret != TD_SUCCESS, ret, ENUM_SVP_ERR_LEVEL_ERROR,
"Error(%#x), check param of src->right_matrix fail\n", ret);
/* left matrix */
matrix_data = macro_svp_convert_addr_to_ptr(td_float, src->left_matrix.virt_addr);
stride = src->left_matrix.stride / sizeof(td_u32); /* fp32 */
size = src->left_matrix.stride * src->left_matrix.shape.whc.height;
for (i = 0; i < src->left_matrix.shape.whc.height; i++) {
tmp_data = (td_float)(i + 1);
for (j = 0; j < src->left_matrix.shape.whc.width; j++) {
matrix_data[j] = tmp_data;
tmp_data++;
}
matrix_data += stride;
}
ret = sample_common_svp_flush_cache(src->left_matrix.phys_addr,
macro_svp_convert_addr_to_ptr(td_void, src->left_matrix.virt_addr), size);
macro_svp_check_exps_return(ret != TD_SUCCESS, ret,
ENUM_SVP_ERR_LEVEL_ERROR, "Error, flush cache failed!\n");
/* right matrix */
matrix_data = macro_svp_convert_addr_to_ptr(td_float, src->right_matrix.virt_addr);
stride = src->right_matrix.stride / sizeof(td_u32); /* fp32 */
size = src->right_matrix.stride * src->right_matrix.shape.whc.height;
for (i = 0; i < src->right_matrix.shape.whc.height; i++) {
tmp_data = (td_float)(i + 1);
for (j = 0; j < src->right_matrix.shape.whc.width; j++) {
matrix_data[j] = tmp_data;
tmp_data++;
}
matrix_data += stride;
}
ret = sample_common_svp_flush_cache(src->right_matrix.phys_addr,
macro_svp_convert_addr_to_ptr(td_void, src->right_matrix.virt_addr), size);
macro_svp_check_exps_return(ret != TD_SUCCESS, ret,
ENUM_SVP_ERR_LEVEL_ERROR, "Error, flush cache failed!\n");
return TD_SUCCESS;
}