294 lines
14 KiB
C
294 lines
14 KiB
C
![]() |
/*
  Copyright (c), 2001-2022, Shenshu Tech. Co., Ltd.
*/
|
||
|
#include "sample_common_mau.h"
|
||
|
#include "libapi_common_svp.h"
|
||
|
#include "ss_mpi_sys.h"
|
||
|
|
||
|
/*
 * Module-local flag: set to TD_TRUE once sample_svp_mau_mpi_init() has
 * succeeded, and reset to TD_FALSE by sample_common_svp_mau_mpi_exit().
 */
static td_bool g_mpi_init = TD_FALSE;
|
||
|
|
||
|
static td_s32 sample_svp_mau_mpi_init(td_void)
|
||
|
{
|
||
|
td_s32 ret;
|
||
|
|
||
|
ret = ss_mpi_sys_exit();
|
||
|
macro_svp_check_exps_return(ret != TD_SUCCESS, ret, ENUM_SVP_ERR_LEVEL_ERROR,
|
||
|
"Error(%#x):ss_mpi_sys_exit failed!\n", ret);
|
||
|
ret = ss_mpi_sys_init();
|
||
|
macro_svp_check_exps_return(ret != TD_SUCCESS, ret, ENUM_SVP_ERR_LEVEL_ERROR,
|
||
|
"Error(%#x):ss_mpi_sys_init failed!\n", ret);
|
||
|
|
||
|
return TD_SUCCESS;
|
||
|
}
|
||
|
|
||
|
td_s32 sample_common_svp_mau_check_mau_mpi_init(td_void)
|
||
|
{
|
||
|
if (g_mpi_init == TD_FALSE) {
|
||
|
if (sample_svp_mau_mpi_init() != TD_SUCCESS) {
|
||
|
macro_svp_trace_err("mau mpi init failed!\n");
|
||
|
return TD_FALSE;
|
||
|
}
|
||
|
g_mpi_init = TD_TRUE;
|
||
|
}
|
||
|
return TD_TRUE;
|
||
|
}
|
||
|
td_s32 sample_common_svp_mau_mpi_exit(td_void)
|
||
|
{
|
||
|
g_mpi_init = TD_FALSE;
|
||
|
if (ss_mpi_sys_exit() != TD_SUCCESS) {
|
||
|
macro_svp_trace_err("Sys exit failed!\n");
|
||
|
return TD_FAILURE;
|
||
|
}
|
||
|
return TD_SUCCESS;
|
||
|
}
|
||
|
|
||
|
static td_s32 sample_svp_mau_check_src_idx_blob_info(const ot_svp_blob *matrix_blob)
|
||
|
{
|
||
|
td_u32 stride;
|
||
|
td_u32 byte_num = (td_u32)sizeof(td_u32);
|
||
|
|
||
|
macro_svp_check_exps_return(matrix_blob->virt_addr == 0, OT_ERR_SVP_MAU_ILLEGAL_PARAM,
|
||
|
ENUM_SVP_ERR_LEVEL_ERROR, "Error, blob->virt_addr can't be zero!\n");
|
||
|
macro_svp_check_exps_return(matrix_blob->phys_addr == 0, OT_ERR_SVP_MAU_ILLEGAL_PARAM,
|
||
|
ENUM_SVP_ERR_LEVEL_ERROR, "Error, blob->phys_addr can't be zero!\n");
|
||
|
/* check phys_addr 4 byte aligned */
|
||
|
macro_svp_check_exps_return(sample_svp_mau_check_align(matrix_blob->phys_addr, byte_num) != 0,
|
||
|
OT_ERR_SVP_MAU_ILLEGAL_PARAM, ENUM_SVP_ERR_LEVEL_ERROR,
|
||
|
"Error(%#x),blob->phys_addr(%llu) should be %u bytes aligned\n",
|
||
|
OT_ERR_SVP_MAU_ILLEGAL_PARAM, matrix_blob->phys_addr, byte_num);
|
||
|
/* check type */
|
||
|
macro_svp_check_exps_return(matrix_blob->type != OT_SVP_BLOB_TYPE_U32, OT_ERR_SVP_MAU_ILLEGAL_PARAM,
|
||
|
ENUM_SVP_ERR_LEVEL_ERROR, "Error(%#x),blob->type(%d) must be %d\n", OT_ERR_SVP_MAU_ILLEGAL_PARAM,
|
||
|
matrix_blob->type, OT_SVP_BLOB_TYPE_U32);
|
||
|
/* check num */
|
||
|
macro_svp_check_exps_return(matrix_blob->num != 1, OT_ERR_SVP_MAU_ILLEGAL_PARAM, ENUM_SVP_ERR_LEVEL_ERROR,
|
||
|
"Error(%#x),blob->num(%u) must be 1\n", OT_ERR_SVP_MAU_ILLEGAL_PARAM, matrix_blob->num);
|
||
|
/* check chn */
|
||
|
macro_svp_check_exps_return(matrix_blob->shape.whc.chn != 1, OT_ERR_SVP_MAU_ILLEGAL_PARAM,
|
||
|
ENUM_SVP_ERR_LEVEL_ERROR, "Error(%#x),blob->shape.whc.chn(%u) must be 1\n",
|
||
|
OT_ERR_SVP_MAU_ILLEGAL_PARAM, matrix_blob->shape.whc.chn);
|
||
|
|
||
|
macro_svp_check_exps_return((matrix_blob->shape.whc.width == 0) ||
|
||
|
(matrix_blob->shape.whc.width > SAMPLE_SVP_MAU_MATRIX_MAX_HEIGHT), OT_ERR_SVP_MAU_ILLEGAL_PARAM,
|
||
|
ENUM_SVP_ERR_LEVEL_ERROR, "Error(%#x), blob->shape.whc.width(%u) must be (0, %u]\n",
|
||
|
OT_ERR_SVP_MAU_ILLEGAL_PARAM, matrix_blob->shape.whc.width, SAMPLE_SVP_MAU_MATRIX_MAX_HEIGHT);
|
||
|
|
||
|
macro_svp_check_exps_return(matrix_blob->shape.whc.height != 1, OT_ERR_SVP_MAU_ILLEGAL_PARAM,
|
||
|
ENUM_SVP_ERR_LEVEL_ERROR, "Error(%#x), blob->shape.whc.height(%u) must be 1\n",
|
||
|
OT_ERR_SVP_MAU_ILLEGAL_PARAM, matrix_blob->shape.whc.height);
|
||
|
|
||
|
/* check stride 16 bytes aligned */
|
||
|
macro_svp_check_exps_return(sample_svp_mau_check_align(matrix_blob->stride, SAMPLE_SVP_MAU_ALIGN_16) != 0,
|
||
|
OT_ERR_SVP_MAU_ILLEGAL_PARAM, ENUM_SVP_ERR_LEVEL_ERROR,
|
||
|
"Error(%#x),blob->stride(%u) should be %u bytes aligned\n",
|
||
|
OT_ERR_SVP_MAU_ILLEGAL_PARAM, matrix_blob->stride, SAMPLE_SVP_MAU_ALIGN_16);
|
||
|
stride = sample_common_svp_align(matrix_blob->shape.whc.width * byte_num, SAMPLE_SVP_MAU_ALIGN_16);
|
||
|
macro_svp_check_exps_return((matrix_blob->stride == 0) ||
|
||
|
(matrix_blob->stride < stride), OT_ERR_SVP_MAU_ILLEGAL_PARAM, ENUM_SVP_ERR_LEVEL_ERROR,
|
||
|
"Error(%#x), blob->stride(%u) can't be 0 and should be equal to or greater than %u\n",
|
||
|
OT_ERR_SVP_MAU_ILLEGAL_PARAM, matrix_blob->stride, stride);
|
||
|
|
||
|
return TD_SUCCESS;
|
||
|
}
|
||
|
|
||
|
static td_s32 sample_svp_mau_set_idx_info_check_param(const ot_svp_mau_src_double_matrix *src_idx,
|
||
|
const ot_svp_mau_ctrl *ctrl, td_u32 left_matrix_height, td_u32 right_matrix_height)
|
||
|
{
|
||
|
td_s32 ret;
|
||
|
|
||
|
macro_svp_check_exps_return(ctrl == TD_NULL, OT_ERR_SVP_MAU_NULL_PTR,
|
||
|
ENUM_SVP_ERR_LEVEL_ERROR, "Error, ctrl is TD_NULL!\n");
|
||
|
macro_svp_check_exps_return(src_idx == TD_NULL, OT_ERR_SVP_MAU_NULL_PTR,
|
||
|
ENUM_SVP_ERR_LEVEL_ERROR, "Error, src_idx is TD_NULL!\n");
|
||
|
|
||
|
macro_svp_check_exps_return((ctrl->has_left_idx != TD_FALSE) && (ctrl->has_left_idx != TD_TRUE),
|
||
|
OT_ERR_SVP_MAU_ILLEGAL_PARAM, ENUM_SVP_ERR_LEVEL_ERROR,
|
||
|
"Error(%#x),ctrl->has_left_idx(%d) must be [%d, %d]\n",
|
||
|
OT_ERR_SVP_MAU_ILLEGAL_PARAM, ctrl->has_left_idx, TD_FALSE, TD_TRUE);
|
||
|
macro_svp_check_exps_return((ctrl->has_right_idx != TD_FALSE) && (ctrl->has_right_idx != TD_TRUE),
|
||
|
OT_ERR_SVP_MAU_ILLEGAL_PARAM, ENUM_SVP_ERR_LEVEL_ERROR,
|
||
|
"Error(%#x),ctrl->has_right_idx(%d) must be [%d, %d]\n",
|
||
|
OT_ERR_SVP_MAU_ILLEGAL_PARAM, ctrl->has_right_idx, TD_FALSE, TD_TRUE);
|
||
|
|
||
|
macro_svp_check_exps_return((left_matrix_height == 0) ||
|
||
|
(left_matrix_height > SAMPLE_SVP_MAU_MATRIX_MAX_HEIGHT), OT_ERR_SVP_MAU_ILLEGAL_PARAM,
|
||
|
ENUM_SVP_ERR_LEVEL_ERROR, "left_matrix_height(%u) must be (0, %u]\n",
|
||
|
left_matrix_height, SAMPLE_SVP_MAU_MATRIX_MAX_HEIGHT);
|
||
|
macro_svp_check_exps_return((right_matrix_height == 0) ||
|
||
|
(right_matrix_height > SAMPLE_SVP_MAU_MATRIX_MAX_HEIGHT), OT_ERR_SVP_MAU_ILLEGAL_PARAM,
|
||
|
ENUM_SVP_ERR_LEVEL_ERROR, "right_matrix_height(%u) must be (0, %u]\n",
|
||
|
right_matrix_height, SAMPLE_SVP_MAU_MATRIX_MAX_HEIGHT);
|
||
|
|
||
|
if (ctrl->has_left_idx == TD_TRUE) {
|
||
|
ret = sample_svp_mau_check_src_idx_blob_info(&src_idx->left_matrix);
|
||
|
macro_svp_check_exps_return(ret != TD_SUCCESS, ret, ENUM_SVP_ERR_LEVEL_ERROR,
|
||
|
"Error(%#x), sample_svp_mau_check_src_idx_left_matrix fail\n", ret);
|
||
|
}
|
||
|
if (ctrl->has_right_idx == TD_TRUE) {
|
||
|
ret = sample_svp_mau_check_src_idx_blob_info(&src_idx->right_matrix);
|
||
|
macro_svp_check_exps_return(ret != TD_SUCCESS, ret, ENUM_SVP_ERR_LEVEL_ERROR,
|
||
|
"Error(%#x), sample_svp_mau_check_src_idx_right_matrix fail\n", ret);
|
||
|
}
|
||
|
|
||
|
return TD_SUCCESS;
|
||
|
}
|
||
|
td_s32 sample_svp_mau_set_idx_info(ot_svp_mau_src_double_matrix *src_idx,
|
||
|
const ot_svp_mau_ctrl *ctrl, const ot_svp_mau_src_double_matrix *src)
|
||
|
{
|
||
|
td_u32 i;
|
||
|
td_u32 *idx = TD_NULL;
|
||
|
td_s32 ret;
|
||
|
td_u32 tmp;
|
||
|
td_u32 left_matrix_height, right_matrix_height;
|
||
|
|
||
|
macro_svp_check_exps_return(src == TD_NULL, OT_ERR_SVP_MAU_NULL_PTR,
|
||
|
ENUM_SVP_ERR_LEVEL_ERROR, "Error, src is TD_NULL!\n");
|
||
|
|
||
|
left_matrix_height = src->left_matrix.shape.whc.height;
|
||
|
right_matrix_height = src->right_matrix.shape.whc.height;
|
||
|
ret = sample_svp_mau_set_idx_info_check_param(src_idx, ctrl, left_matrix_height, right_matrix_height);
|
||
|
macro_svp_check_exps_return(ret != TD_SUCCESS, ret, ENUM_SVP_ERR_LEVEL_ERROR,
|
||
|
"Error(%#x), sample_svp_mau_set_idx_info_check_param fail\n", ret);
|
||
|
|
||
|
/* set left matrix idx */
|
||
|
if (ctrl->has_left_idx == TD_TRUE) {
|
||
|
idx = macro_svp_convert_addr_to_ptr(td_u32, src_idx->left_matrix.virt_addr);
|
||
|
for (i = 0; i < src_idx->left_matrix.shape.whc.width; i++) {
|
||
|
tmp = i * SAMPLE_SVP_MAU_GENERATE_IDX_INTERVAL;
|
||
|
if (tmp < left_matrix_height) {
|
||
|
*(idx + i) = tmp;
|
||
|
} else {
|
||
|
*(idx + i) = left_matrix_height - 1;
|
||
|
}
|
||
|
}
|
||
|
ret = sample_common_svp_flush_cache(src_idx->left_matrix.phys_addr,
|
||
|
macro_svp_convert_addr_to_ptr(td_void, src_idx->left_matrix.virt_addr),
|
||
|
src_idx->left_matrix.shape.whc.height * src_idx->left_matrix.stride);
|
||
|
macro_svp_check_exps_return(ret != TD_SUCCESS, ret,
|
||
|
ENUM_SVP_ERR_LEVEL_ERROR, "Error,flush cache failed!\n");
|
||
|
}
|
||
|
|
||
|
/* set right matrix idx */
|
||
|
if (ctrl->has_right_idx == TD_TRUE) {
|
||
|
idx = macro_svp_convert_addr_to_ptr(td_u32, src_idx->right_matrix.virt_addr);
|
||
|
for (i = 0; i < src_idx->right_matrix.shape.whc.width; i++) {
|
||
|
tmp = i * SAMPLE_SVP_MAU_GENERATE_IDX_INTERVAL;
|
||
|
if (tmp < right_matrix_height) {
|
||
|
*(idx + i) = tmp;
|
||
|
} else {
|
||
|
*(idx + i) = right_matrix_height - 1;
|
||
|
}
|
||
|
}
|
||
|
ret = sample_common_svp_flush_cache(src_idx->right_matrix.phys_addr,
|
||
|
macro_svp_convert_addr_to_ptr(td_void, src_idx->right_matrix.virt_addr),
|
||
|
src_idx->right_matrix.shape.whc.height * src_idx->right_matrix.stride);
|
||
|
macro_svp_check_exps_return(ret != TD_SUCCESS, ret,
|
||
|
ENUM_SVP_ERR_LEVEL_ERROR, "Error,flush cache failed!\n");
|
||
|
}
|
||
|
return TD_SUCCESS;
|
||
|
}
|
||
|
|
||
|
static td_s32 sample_svp_mau_check_src_blob_info(const ot_svp_blob *matrix_blob)
|
||
|
{
|
||
|
td_u32 stride;
|
||
|
td_u32 byte_num = (td_u32)sizeof(td_u32);
|
||
|
|
||
|
macro_svp_check_exps_return(matrix_blob->virt_addr == 0, OT_ERR_SVP_MAU_ILLEGAL_PARAM,
|
||
|
ENUM_SVP_ERR_LEVEL_ERROR, "Error, blob->virt_addr can't be zero!\n");
|
||
|
macro_svp_check_exps_return(matrix_blob->phys_addr == 0, OT_ERR_SVP_MAU_ILLEGAL_PARAM,
|
||
|
ENUM_SVP_ERR_LEVEL_ERROR, "Error, blob->phys_addr can't be zero!\n");
|
||
|
/* check phys_addr 4 byte aligned */
|
||
|
macro_svp_check_exps_return(sample_svp_mau_check_align(matrix_blob->phys_addr, byte_num) != 0,
|
||
|
OT_ERR_SVP_MAU_ILLEGAL_PARAM, ENUM_SVP_ERR_LEVEL_ERROR,
|
||
|
"Error(%#x),blob->phys_addr(%llu) should be %u bytes aligned\n",
|
||
|
OT_ERR_SVP_MAU_ILLEGAL_PARAM, matrix_blob->phys_addr, byte_num);
|
||
|
/* check type */
|
||
|
macro_svp_check_exps_return(matrix_blob->type != OT_SVP_BLOB_TYPE_FP32, OT_ERR_SVP_MAU_ILLEGAL_PARAM,
|
||
|
ENUM_SVP_ERR_LEVEL_ERROR, "Error(%#x),blob->type(%d) must be %d\n", OT_ERR_SVP_MAU_ILLEGAL_PARAM,
|
||
|
matrix_blob->type, OT_SVP_BLOB_TYPE_FP32);
|
||
|
/* check num */
|
||
|
macro_svp_check_exps_return(matrix_blob->num != 1, OT_ERR_SVP_MAU_ILLEGAL_PARAM, ENUM_SVP_ERR_LEVEL_ERROR,
|
||
|
"Error(%#x),blob->num(%u) must be 1\n", OT_ERR_SVP_MAU_ILLEGAL_PARAM, matrix_blob->num);
|
||
|
/* check chn */
|
||
|
macro_svp_check_exps_return(matrix_blob->shape.whc.chn != 1, OT_ERR_SVP_MAU_ILLEGAL_PARAM,
|
||
|
ENUM_SVP_ERR_LEVEL_ERROR, "Error(%#x),blob->shape.whc.chn(%u) must be 1\n",
|
||
|
OT_ERR_SVP_MAU_ILLEGAL_PARAM, matrix_blob->shape.whc.chn);
|
||
|
|
||
|
macro_svp_check_exps_return((matrix_blob->shape.whc.width == 0) ||
|
||
|
(matrix_blob->shape.whc.width > SAMPLE_SVP_MAU_MATRIX_MAX_FP32_WIDTH), OT_ERR_SVP_MAU_ILLEGAL_PARAM,
|
||
|
ENUM_SVP_ERR_LEVEL_ERROR, "Error(%#x),matrix_blob->shape.whc.width(%u) must be (0, %u]\n",
|
||
|
OT_ERR_SVP_MAU_ILLEGAL_PARAM, matrix_blob->shape.whc.width, SAMPLE_SVP_MAU_MATRIX_MAX_FP32_WIDTH);
|
||
|
macro_svp_check_exps_return((matrix_blob->shape.whc.height == 0) ||
|
||
|
(matrix_blob->shape.whc.height > SAMPLE_SVP_MAU_MATRIX_MAX_HEIGHT), OT_ERR_SVP_MAU_ILLEGAL_PARAM,
|
||
|
ENUM_SVP_ERR_LEVEL_ERROR, "Error,matrix_blob->shape.whcheight(%u) must be (0, %u]\n",
|
||
|
matrix_blob->shape.whc.height, SAMPLE_SVP_MAU_MATRIX_MAX_HEIGHT);
|
||
|
|
||
|
/* check stride 16 bytes aligned */
|
||
|
macro_svp_check_exps_return(sample_svp_mau_check_align(matrix_blob->stride, SAMPLE_SVP_MAU_ALIGN_16) != 0,
|
||
|
OT_ERR_SVP_MAU_ILLEGAL_PARAM, ENUM_SVP_ERR_LEVEL_ERROR,
|
||
|
"Error(%#x),blob->stride(%u) should be %u bytes aligned\n",
|
||
|
OT_ERR_SVP_MAU_ILLEGAL_PARAM, matrix_blob->stride, SAMPLE_SVP_MAU_ALIGN_16);
|
||
|
stride = sample_common_svp_align(matrix_blob->shape.whc.width * byte_num, SAMPLE_SVP_MAU_ALIGN_16);
|
||
|
macro_svp_check_exps_return((matrix_blob->stride == 0) ||
|
||
|
(matrix_blob->stride < stride), OT_ERR_SVP_MAU_ILLEGAL_PARAM, ENUM_SVP_ERR_LEVEL_ERROR,
|
||
|
"Error(%#x), blob->stride(%u) can't be 0 and should be equal to or greater than %u\n",
|
||
|
OT_ERR_SVP_MAU_ILLEGAL_PARAM, matrix_blob->stride, stride);
|
||
|
|
||
|
return TD_SUCCESS;
|
||
|
}
|
||
|
|
||
|
td_s32 sample_svp_mau_generate_matrix_data(const ot_svp_mau_src_double_matrix *src)
|
||
|
{
|
||
|
td_s32 ret;
|
||
|
td_u32 i, j;
|
||
|
td_float *matrix_data = TD_NULL;
|
||
|
td_u32 stride, size;
|
||
|
td_float tmp_data;
|
||
|
|
||
|
/* check */
|
||
|
macro_svp_check_exps_return(src == TD_NULL, OT_ERR_SVP_MAU_NULL_PTR,
|
||
|
ENUM_SVP_ERR_LEVEL_ERROR, "Error, src is TD_NULL!\n");
|
||
|
ret = sample_svp_mau_check_src_blob_info(&src->left_matrix);
|
||
|
macro_svp_check_exps_return(ret != TD_SUCCESS, ret, ENUM_SVP_ERR_LEVEL_ERROR,
|
||
|
"Error(%#x), check param of src->left_matrix fail\n", ret);
|
||
|
ret = sample_svp_mau_check_src_blob_info(&src->right_matrix);
|
||
|
macro_svp_check_exps_return(ret != TD_SUCCESS, ret, ENUM_SVP_ERR_LEVEL_ERROR,
|
||
|
"Error(%#x), check param of src->right_matrix fail\n", ret);
|
||
|
|
||
|
/* left matrix */
|
||
|
matrix_data = macro_svp_convert_addr_to_ptr(td_float, src->left_matrix.virt_addr);
|
||
|
stride = src->left_matrix.stride / sizeof(td_u32); /* fp32 */
|
||
|
size = src->left_matrix.stride * src->left_matrix.shape.whc.height;
|
||
|
for (i = 0; i < src->left_matrix.shape.whc.height; i++) {
|
||
|
tmp_data = (td_float)(i + 1);
|
||
|
for (j = 0; j < src->left_matrix.shape.whc.width; j++) {
|
||
|
matrix_data[j] = tmp_data;
|
||
|
tmp_data++;
|
||
|
}
|
||
|
matrix_data += stride;
|
||
|
}
|
||
|
ret = sample_common_svp_flush_cache(src->left_matrix.phys_addr,
|
||
|
macro_svp_convert_addr_to_ptr(td_void, src->left_matrix.virt_addr), size);
|
||
|
macro_svp_check_exps_return(ret != TD_SUCCESS, ret,
|
||
|
ENUM_SVP_ERR_LEVEL_ERROR, "Error, flush cache failed!\n");
|
||
|
|
||
|
/* right matrix */
|
||
|
matrix_data = macro_svp_convert_addr_to_ptr(td_float, src->right_matrix.virt_addr);
|
||
|
stride = src->right_matrix.stride / sizeof(td_u32); /* fp32 */
|
||
|
size = src->right_matrix.stride * src->right_matrix.shape.whc.height;
|
||
|
for (i = 0; i < src->right_matrix.shape.whc.height; i++) {
|
||
|
tmp_data = (td_float)(i + 1);
|
||
|
for (j = 0; j < src->right_matrix.shape.whc.width; j++) {
|
||
|
matrix_data[j] = tmp_data;
|
||
|
tmp_data++;
|
||
|
}
|
||
|
matrix_data += stride;
|
||
|
}
|
||
|
ret = sample_common_svp_flush_cache(src->right_matrix.phys_addr,
|
||
|
macro_svp_convert_addr_to_ptr(td_void, src->right_matrix.virt_addr), size);
|
||
|
macro_svp_check_exps_return(ret != TD_SUCCESS, ret,
|
||
|
ENUM_SVP_ERR_LEVEL_ERROR, "Error, flush cache failed!\n");
|
||
|
|
||
|
return TD_SUCCESS;
|
||
|
}
|