# aclnnDenseLightningIndexerGradKLLoss

ops-transformer is CANN's library of transformer-style large-model operators for accelerating networks on NPU. Project: https://gitcode.com/cann/ops-transformer

## Product support

| Product | Supported |
| --- | --- |
| Ascend 950PR / Ascend 950DT | × |
| Atlas A3 training series / Atlas A3 inference series | √ |
| Atlas A2 training series / Atlas A2 inference series | √ |
| Atlas 200I/500 A2 inference products | × |
| Atlas inference series | × |
| Atlas training series | × |

## Function

The DenseLightningIndexerGradKLLoss operator is the backward operator of LightningIndexer, additionally fused with the loss computation. LightningIndexer selects, for each query token, the top-k key tokens with the strongest intrinsic relation to it, reducing the attention computation in long-sequence scenarios and improving inference and training performance of long-sequence networks. In the dense scenario, the LightningIndexerGrad inputs query, key, query_index, and key_index require no sparsification.

### Formulas

The top-k value is computed as

$$ I_{t,:} = W_{t,:}\,\mathrm{ReLU}\big(\tilde{q}_{t,:}\tilde{K}_{:t,:}^\top\big) $$

where $W_{t,:}$ is the $weights$ entry for token $t$, $\tilde{q}_{t,:}$ is row $t$ of $\tilde{q}$ after merging its $G$ query heads, and $\tilde{K}_{:t,:}$ is the first $t$ rows of $\tilde{K}$.

The forward softmax is

$$ p_{t,:} = \mathrm{Softmax}\big(q_{t,:} K_{:t,:}^\top/\sqrt{d}\big) $$

where $p_{t,:}$ is the softmax result for token $t$, $q_{t,:}$ is row $t$ of $q$ after merging its $G$ query heads, and $K_{:t,:}$ is the first $t$ rows of $K$.

npu_lightning_indexer is trained separately; its loss function is

$$ \mathrm{Loss} = \sum_t D_{KL}\big(p_{t,:}\,\big\|\,\mathrm{Softmax}(I_{t,:})\big) $$

where the target distribution $p_{t,:}$ is obtained by summing the main attention scores over all heads and L1-normalizing the sum along the context dimension. $D_{KL}$ is the KL divergence:

$$ D_{KL}(a\,\|\,b) = \sum_i a_i \log\!\left(\frac{a_i}{b_i}\right) $$

Differentiating the loss gives its gradient:

$$ dI_{t,:} = \mathrm{Softmax}(I_{t,:}) - p_{t,:} $$

and the chain rule then yields the gradients of the weights, query, and key matrices:

$$ dW_{t,:} = dI_{t,:}\,\big(\mathrm{ReLU}(S_{t,:})\big)^\top $$

$$ d\tilde{q}_{t,:} = dS_{t,:}\,\tilde{K}_{:t,:} $$

$$ d\tilde{K}_{:t,:} = \big(dS_{t,:}\big)^\top\,\tilde{q}_{t,:} $$

where $S$ is the result of the matrix product of $\tilde{q}$ and $\tilde{K}$.

## Function prototype

The operator uses a two-stage interface: first call `aclnnDenseLightningIndexerGradKLLossGetWorkspaceSize` to validate the inputs and compute the workspace size required by the calculation, then call `aclnnDenseLightningIndexerGradKLLoss` to execute it.

```cpp
aclnnStatus aclnnDenseLightningIndexerGradKLLossGetWorkspaceSize(
    const aclTensor *query, const aclTensor *key,
    const aclTensor *queryIndex, const aclTensor *keyIndex,
    const aclTensor *weights,
    const aclTensor *softmaxMax, const aclTensor *softmaxSum,
    const aclTensor *softmaxMaxIndex, const aclTensor *softmaxSumIndex,
    const aclTensor *queryRope, const aclTensor *keyRope,
    const aclIntArray *actualSeqLengthsQuery,
    const aclIntArray *actualSeqLengthsKey,
    double scaleValue, char *layout, int64_t sparseMode,
    int64_t preTokens, int64_t nextTokens,
    const aclTensor *dQueryIndex, const aclTensor *dKeyIndex,
    const aclTensor *dWeights, const aclTensor *loss,
    uint64_t *workspaceSize, aclOpExecutor **executor);

aclnnStatus aclnnDenseLightningIndexerGradKLLoss(
    void *workspace, uint64_t workspaceSize,
    aclOpExecutor *executor, aclrtStream stream);
```

## aclnnDenseLightningIndexerGradKLLossGetWorkspaceSize parameters

| Name | In/Out | Description | Usage notes | Data types | Format | Shape | Non-contiguous tensor |
| --- | --- | --- | --- | --- | --- | --- | --- |
| query (aclTensor*) | Input | Input Q of the attention structure. | B: generic. S1: generic. N1: 128, 64, or 32. D: 128. T1: sum of S1 over all batches. | FLOAT16, BFLOAT16 | ND | (B,S1,N1,D); (T1,N1,D) | × |
| key (aclTensor*) | Input | Input K of the attention structure. | B: generic, must match query's B. S2: generic. N2: equal to N1. D: 128. T2: sum of S2 over all batches. | FLOAT16, BFLOAT16 | ND | (B,S2,N2,D); (T2,N2,D) | × |
| queryIndex (aclTensor*) | Input | Input queryIndex of the lightningIndexer structure. | B: generic, must match query's B. S1: generic. Nidx1: 64, 32, 16, or 8. D: 128. T1: sum of S1 over all batches. | FLOAT16, BFLOAT16 | ND | (B,S1,Nidx1,D); (T1,Nidx1,D) | × |
| keyIndex (aclTensor*) | Input | Input keyIndex of the lightningIndexer structure. | B: generic, must match query's B. S2: generic. Nidx2: 1. D: 128. T2: sum of S2 over all batches. | FLOAT16, BFLOAT16 | ND | (B,S2,Nidx2,D); (T2,Nidx2,D) | × |
| weights (aclTensor*) | Input | Weights. | B: generic, must match query's B. S1: generic, must match query's S1. Nidx1: 64, 32, 16, or 8. T1: sum of S1 over all batches. | FLOAT16, BFLOAT16, FLOAT32 | ND | (B,S1,Nidx1); (T1,Nidx1) | × |
| softmaxMax (aclTensor*) | Input | Device-side aclTensor; intermediate output of the forward attention computation. | B: matches query's B. N2: equal to N1. S1: matches query's S1. G: N1/N2. T1: sum of S1 over all batches. | FLOAT32 | ND | (B,N2,S1,G); (N2,T1,G) | × |
| softmaxSum (aclTensor*) | Input | Device-side aclTensor; intermediate output of the forward attention computation. | B: matches query's B. N2: equal to N1. S1: matches query's S1. G: N1/N2. T1: sum of S1 over all batches. | FLOAT32 | ND | (B,N2,S1,G); (N2,T1,G) | × |
| softmaxMaxIndex (aclTensor*) | Input | Device-side aclTensor; intermediate output of the forward attention computation. | B: matches query's B. Nidx2: 1. S1: matches query's S1. T1: sum of S1 over all batches. | FLOAT32 | ND | (B,Nidx2,S1); (Nidx2,T1) | × |
| softmaxSumIndex (aclTensor*) | Input | Device-side aclTensor; intermediate output of the forward attention computation. | B: matches query's B. Nidx2: 1. S1: matches query's S1. T1: sum of S1 over all batches. | FLOAT32 | ND | (B,Nidx2,S1); (Nidx2,T1) | √ |
| queryRope (aclTensor*) | Input | Output of the Query positional encoding in the MLA rope part; layout matches query. | B: matches query's B. S1: matches query's S1. N1: 128, 64, or 32. Dr: 64. T1: sum of S1 over all batches. | FLOAT16, BFLOAT16 | ND | (B,S1,N1,Dr); (T1,N1,Dr) | √ |
| keyRope (aclTensor*) | Input | Output of the Key positional encoding in the MLA rope part; layout matches key. | B: matches query's B. S2: matches key's S2. N2: equal to N1. Dr: 64. T2: sum of S2 over all batches. | FLOAT16, BFLOAT16 | ND | (B,S2,N2,Dr); (T2,N2,Dr) | √ |
| actualSeqLengthsQuery (aclIntArray*) | Input | Number of valid Query tokens per batch. | Value-dependent. Length matches B; sum matches T1. | INT64 | ND | (B,) | - |
| actualSeqLengthsKey (aclIntArray*) | Input | Number of valid Key tokens per batch. | Value-dependent. Length matches B; sum matches T2. | INT64 | ND | (B,) | - |
| scaleValue (double) | Input | Scale factor. | Recommended value: the reciprocal of $\sqrt{d}$ from the formula. | - | - | - | - |
| layout (char*) | Input | Layout format. | Only BSND and TND are supported. | - | - | - | - |
| sparseMode (int64_t) | Input | Sparse mode. | See the constraints section for the mode definitions; only mode 3 is supported. | - | - | - | - |
| preTokens (int64_t) | Input | For sparse computation: how many preceding tokens attention relates to; same meaning as preTokens in Attention. | Effective for sparseMode 0 and 4; default 2^63-1. | - | - | - | - |
| nextTokens (int64_t) | Input | For sparse computation: how many following tokens attention relates to; same meaning as nextTokens in Attention. | Effective for sparseMode 0 and 4; default 2^63-1. | - | - | - | - |
| dQueryIndex (aclTensor*) | Output | Gradient of queryIndex. | B: matches query's B. S1: matches query's S1. Nidx1: 64, 32, 16, or 8. D: 128. T1: sum of S1 over all batches. | FLOAT16, BFLOAT16 | ND | (B,S1,Nidx1,D); (T1,Nidx1,D) | √ |
| dKeyIndex (aclTensor*) | Output | Gradient of keyIndex. | B: matches query's B. S2: matches key's S2. Nidx2: 1. D: 128. T2: sum of S2 over all batches. | FLOAT16, BFLOAT16 | ND | (B,S2,Nidx2,D); (T2,Nidx2,D) | √ |
| dWeights (aclTensor*) | Output | Gradient of weights. | B: generic. S1: generic; cannot be the M axis of the Matmul. Nidx1: 64, 32, 16, or 8. T1: sum of S1 over all batches. | FLOAT16, BFLOAT16, FLOAT32 | ND | (B,S1,Nidx1); (T1,Nidx1) | √ |
| loss (aclTensor*) | Output | Loss value. | - | FLOAT32 | ND | (1,) | - |
| workspaceSize (uint64_t*) | Output | Returns the workspace size to allocate on the device. | - | - | - | - | - |
| executor (aclOpExecutor**) | Output | Returns the op executor, which contains the operator's computation flow. | - | - | - | - | - |

### Return value

Returns an aclnnStatus code; see the aclnn return codes. The first-stage interface validates the inputs and reports errors in the following cases:

| Return value | Code | Description |
| --- | --- | --- |
| ACLNN_ERR_PARAM_NULLPTR | 161001 | A required parameter or an output is a null pointer. |
| ACLNN_ERR_PARAM_INVALID | 161002 | The data type or format of an input such as query, key, queryIndex, keyIndex, weights, or softmaxMax is outside the supported range. |
| ACLNN_ERR_INNER_TILING_ERROR | 561002 | The shapes of the input tensors do not match each other; see the parameter descriptions. |

## aclnnDenseLightningIndexerGradKLLoss parameters

| Name | In/Out | Description |
| --- | --- | --- |
| workspace | Input | Address of the workspace memory allocated on the device. |
| workspaceSize | Input | Size of the device-side workspace, obtained from the first-stage interface aclnnDenseLightningIndexerGradKLLossGetWorkspaceSize. |
| executor | Input | Op executor containing the operator's computation flow. |
| stream | Input | Stream on which the task executes. |

### Return value

Returns an aclnnStatus code; see the aclnn return codes.

## Constraints

- The inputs query, key, queryIndex, and keyIndex must share the same data type.
- When weights is not float32, the inputs query, key, queryIndex, keyIndex, and weights must share the same data type.

Common constraints:

- Deterministic computation: aclnnDenseLightningIndexerGradKLLoss is non-deterministic by default; determinism can be enabled via aclrtCtxSetSysParamOpt.
- Empty inputs: an empty tensor for query, key, query_index, key_index, or weights is currently unsupported and causes an error.

| sparseMode | Meaning | Support |
| --- | --- | --- |
| 0 | defaultMask mode: if attenmask is not passed, no masking is done and preTokens/nextTokens are ignored; if it is passed, the full attenmask matrix is required, marking the region between preTokens and nextTokens to be computed. | Not supported |
| 1 | allMask: the full attenmask matrix must be passed. | Not supported |
| 2 | leftUpCausal mask: an optimized attenmask matrix must be passed. | Not supported |
| 3 | rightDownCausal mask: the lower triangle anchored at the bottom-right vertex; an optimized attenmask matrix must be passed. | Supported |
| 4 | band mask: an optimized attenmask matrix must be passed. | Not supported |
| 5 | prefix | Not supported |
| 6 | global | Not supported |
| 7 | dilated | Not supported |
| 8 | block_local | Not supported |

### Specification constraints

| Item | Spec | Note |
| --- | --- | --- |
| B | 1–256 | - |
| S1, S2 | 1–128K | S1 and S2 may have unequal lengths. |
| N1 | 32, 64, 128 | - |
| Nidx1 | 8, 16, 32, 64 | - |
| N2 | 32, 64, 128 | - |
| Nidx2 | 1 | - |
| D | 128 | query and query_index share the same D. |
| Drope | 64 | - |
| layout | BSND/TND | - |

### Typical values

| Item | Typical value |
| --- | --- |
| query | N1 = 128/64/32; D = 128 |
| queryIndex | Nidx1 = 64/32/16/8; D = 128; S1 = 64K/128K |
| keyIndex | D = 128 |
| qRope | D = 64 |

## Example

The following example is for reference only; for the build and run procedure, see "compiling and running a sample".

```cpp
#include <iostream>
#include <vector>
#include <cstdint>
#include <cstdio>
#include <cmath>
#include "acl/acl.h"
#include "aclnnop/aclnn_dense_lightning_indexer_grad_kl_loss.h"

#define CHECK_RET(cond, return_expr) \
    do {                             \
        if (!(cond)) {               \
            return_expr;             \
        }                            \
    } while (0)

#define LOG_PRINT(message, ...)         \
    do {                                \
        printf(message, ##__VA_ARGS__); \
    } while (0)
```
```cpp
int64_t GetShapeSize(const std::vector<int64_t>& shape) {
    int64_t shapeSize = 1;
    for (auto i : shape) {
        shapeSize *= i;
    }
    return shapeSize;
}

void PrintOutResult(const std::vector<int64_t>& shape, void** deviceAddr) {
    auto size = GetShapeSize(shape);
    std::vector<aclFloat16> resultData(size, 0);
    auto ret = aclrtMemcpy(resultData.data(), resultData.size() * sizeof(resultData[0]),
                           *deviceAddr, size * sizeof(resultData[0]), ACL_MEMCPY_DEVICE_TO_HOST);
    CHECK_RET(ret == ACL_SUCCESS,
              LOG_PRINT("copy result from device to host failed. ERROR: %d\n", ret); return);
    for (int64_t i = 0; i < size; i++) {
        LOG_PRINT("mean result[%ld] is: %f\n", i, aclFloat16ToFloat(resultData[i]));
    }
}

// The loss output is FLOAT32, so it needs a float printer rather than the fp16 one above.
void PrintFloatResult(const std::vector<int64_t>& shape, void** deviceAddr) {
    auto size = GetShapeSize(shape);
    std::vector<float> resultData(size, 0.0f);
    auto ret = aclrtMemcpy(resultData.data(), resultData.size() * sizeof(float),
                           *deviceAddr, size * sizeof(float), ACL_MEMCPY_DEVICE_TO_HOST);
    CHECK_RET(ret == ACL_SUCCESS,
              LOG_PRINT("copy result from device to host failed. ERROR: %d\n", ret); return);
    for (int64_t i = 0; i < size; i++) {
        LOG_PRINT("mean result[%ld] is: %f\n", i, resultData[i]);
    }
}

int Init(int32_t deviceId, aclrtContext* context, aclrtStream* stream) {
    // Boilerplate: AscendCL initialization
    auto ret = aclInit(nullptr);
    CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclInit failed. ERROR: %d\n", ret); return ret);
    ret = aclrtSetDevice(deviceId);
    CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtSetDevice failed. ERROR: %d\n", ret); return ret);
    ret = aclrtCreateContext(context, deviceId);
    CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtCreateContext failed. ERROR: %d\n", ret); return ret);
    ret = aclrtSetCurrentContext(*context);
    CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtSetCurrentContext failed. ERROR: %d\n", ret); return ret);
    ret = aclrtCreateStream(stream);
    CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtCreateStream failed. ERROR: %d\n", ret); return ret);
    return 0;
}

template <typename T>
int CreateAclTensor(const std::vector<T>& hostData, const std::vector<int64_t>& shape,
                    void** deviceAddr, aclDataType dataType, aclTensor** tensor) {
    auto size = GetShapeSize(shape) * sizeof(T);
    // Allocate device-side memory with aclrtMalloc
    auto ret = aclrtMalloc(deviceAddr, size, ACL_MEM_MALLOC_HUGE_FIRST);
    CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtMalloc failed. ERROR: %d\n", ret); return ret);
    // Copy host-side data to device memory with aclrtMemcpy
    ret = aclrtMemcpy(*deviceAddr, size, hostData.data(), size, ACL_MEMCPY_HOST_TO_DEVICE);
    CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtMemcpy failed. ERROR: %d\n", ret); return ret);
    // Compute the strides of a contiguous tensor
    std::vector<int64_t> strides(shape.size(), 1);
    for (int64_t i = shape.size() - 2; i >= 0; i--) {
        strides[i] = shape[i + 1] * strides[i + 1];
    }
    // Create the aclTensor with aclCreateTensor
    *tensor = aclCreateTensor(shape.data(), shape.size(), dataType, strides.data(), 0,
                              aclFormat::ACL_FORMAT_ND, shape.data(), shape.size(), *deviceAddr);
    return 0;
}

int main() {
    // 1. Boilerplate device/context/stream initialization; see the AscendCL API reference.
    // Set deviceId to your actual device.
    int32_t deviceId = 0;
    aclrtContext context;
    aclrtStream stream;
    auto ret = Init(deviceId, &context, &stream);
    CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("Init acl failed. ERROR: %d\n", ret); return ret);

    // 2. Construct the inputs and outputs as required by the API.
    int64_t s1 = 1;
    int64_t s2 = 1;
    int64_t n1 = 32;
    int64_t n2 = n1;
    int64_t n1Index = 8;
    int64_t n2Index = 1;
    int64_t dQuery = 128;
    int64_t dRope = 64;
    int64_t dQueryIndex = 128;
    int64_t t1 = s1;
    int64_t t2 = s2;
    int64_t G = n1 / n2;
    std::vector<int64_t> qShape = {t1, n1, dQuery};
    std::vector<int64_t> kShape = {t2, n2, dQuery};
    std::vector<int64_t> qRopeShape = {t1, n1, dRope};
    std::vector<int64_t> kRopeShape = {t2, n2, dRope};
    std::vector<int64_t> qIndexShape = {t1, n1Index, dQueryIndex};
    std::vector<int64_t> kIndexShape = {t2, n2Index, dQueryIndex};
    std::vector<int64_t> weightShape = {t1, n1Index};
    std::vector<int64_t> softmaxMaxShape = {n2, t1, G};
    std::vector<int64_t> softmaxSumShape = {n2, t1, G};
    std::vector<int64_t> softmaxMaxIndexShape = {n2Index, t1};
    std::vector<int64_t> softmaxSumIndexShape = {n2Index, t1};
    std::vector<int64_t> dQIndexShape = {t1, n1Index, dQueryIndex};
    std::vector<int64_t> dKIndexShape = {t2, n2Index, dQueryIndex};
    std::vector<int64_t> dWeightShape = {t1, n1Index};
    std::vector<int64_t> lossShape = {1};
    void* qDeviceAddr = nullptr;
    void* kDeviceAddr = nullptr;
    void* qRopeDeviceAddr = nullptr;
    void* kRopeDeviceAddr = nullptr;
    void* qIndexDeviceAddr = nullptr;
    void* kIndexDeviceAddr = nullptr;
    void* weightDeviceAddr = nullptr;
    void* softmaxMaxDeviceAddr = nullptr;
    void* softmaxSumDeviceAddr = nullptr;
    void* softmaxMaxIndexDeviceAddr = nullptr;
    void* softmaxSumIndexDeviceAddr = nullptr;
    void* dQIndexDeviceAddr = nullptr;
    void* dKIndexDeviceAddr = nullptr;
    void* dWeightDeviceAddr = nullptr;
    void* lossDeviceAddr = nullptr;
    aclTensor* q = nullptr;
    aclTensor* k = nullptr;
    aclTensor* qRope = nullptr;
    aclTensor* kRope = nullptr;
    aclTensor* qIndex = nullptr;
    aclTensor* kIndex = nullptr;
    aclTensor* weight = nullptr;
    aclTensor* softmaxMax = nullptr;
    aclTensor* softmaxSum = nullptr;
    aclTensor* softmaxMaxIndex = nullptr;
    aclTensor* softmaxSumIndex = nullptr;
    aclTensor* dQIndex = nullptr;
    aclTensor* dKIndex = nullptr;
    aclTensor* dWeight = nullptr;
    aclTensor* loss = nullptr;
    std::vector<aclFloat16> qHostData(t1 * n1 * dQuery, aclFloatToFloat16(0.1f));
    std::vector<aclFloat16> kHostData(t2 * n2 * dQuery, aclFloatToFloat16(0.2f));
    std::vector<aclFloat16> qRopeHostData(t1 * n1 * dRope, aclFloatToFloat16(0.1f));
    std::vector<aclFloat16> kRopeHostData(t2 * n2 * dRope, aclFloatToFloat16(0.2f));
    std::vector<aclFloat16> qIndexHostData(t1 * n1Index * dQueryIndex, aclFloatToFloat16(0.2f));
    std::vector<aclFloat16> kIndexHostData(t2 * n2Index * dQueryIndex, aclFloatToFloat16(0.1f));
    std::vector<aclFloat16> weightHostData(t1 * n1Index, aclFloatToFloat16(0.005f));
    std::vector<float> softmaxMaxHostData(t1 * n2, 25.4483f);
    std::vector<float> softmaxSumHostData(t1 * n2, 1.0f);
    std::vector<float> softmaxMaxIndexHostData(t1 * n2Index, 25.4483f);
    std::vector<float> softmaxSumIndexHostData(t1 * n2Index, 1.0f);
    std::vector<aclFloat16> dQIndexHostData(t1 * n1Index * dQueryIndex);
    std::vector<aclFloat16> dKIndexHostData(t2 * n2Index * dQueryIndex);
    std::vector<aclFloat16> dWeightHostData(t1 * n1Index);
    std::vector<float> lossHostData(1, 1.0f);
    ret = CreateAclTensor(qHostData, qShape, &qDeviceAddr, aclDataType::ACL_FLOAT16, &q);
    CHECK_RET(ret == ACL_SUCCESS, return ret);
    ret = CreateAclTensor(kHostData, kShape, &kDeviceAddr, aclDataType::ACL_FLOAT16, &k);
    CHECK_RET(ret == ACL_SUCCESS, return ret);
    ret = CreateAclTensor(qRopeHostData, qRopeShape, &qRopeDeviceAddr, aclDataType::ACL_FLOAT16, &qRope);
    CHECK_RET(ret == ACL_SUCCESS, return ret);
    ret = CreateAclTensor(kRopeHostData, kRopeShape, &kRopeDeviceAddr, aclDataType::ACL_FLOAT16, &kRope);
    CHECK_RET(ret == ACL_SUCCESS, return ret);
    ret = CreateAclTensor(qIndexHostData, qIndexShape, &qIndexDeviceAddr, aclDataType::ACL_FLOAT16, &qIndex);
    CHECK_RET(ret == ACL_SUCCESS, return ret);
    ret = CreateAclTensor(kIndexHostData, kIndexShape, &kIndexDeviceAddr, aclDataType::ACL_FLOAT16, &kIndex);
    CHECK_RET(ret == ACL_SUCCESS, return ret);
    ret = CreateAclTensor(weightHostData, weightShape, &weightDeviceAddr, aclDataType::ACL_FLOAT16, &weight);
    CHECK_RET(ret == ACL_SUCCESS, return ret);
    ret = CreateAclTensor(softmaxMaxHostData, softmaxMaxShape, &softmaxMaxDeviceAddr,
                          aclDataType::ACL_FLOAT, &softmaxMax);
    CHECK_RET(ret == ACL_SUCCESS, return ret);
    ret = CreateAclTensor(softmaxSumHostData, softmaxSumShape, &softmaxSumDeviceAddr,
                          aclDataType::ACL_FLOAT, &softmaxSum);
    CHECK_RET(ret == ACL_SUCCESS, return ret);
    ret = CreateAclTensor(softmaxMaxIndexHostData, softmaxMaxIndexShape, &softmaxMaxIndexDeviceAddr,
                          aclDataType::ACL_FLOAT, &softmaxMaxIndex);
    CHECK_RET(ret == ACL_SUCCESS, return ret);
    ret = CreateAclTensor(softmaxSumIndexHostData, softmaxSumIndexShape, &softmaxSumIndexDeviceAddr,
                          aclDataType::ACL_FLOAT, &softmaxSumIndex);
    CHECK_RET(ret == ACL_SUCCESS, return ret);
    ret = CreateAclTensor(dQIndexHostData, dQIndexShape, &dQIndexDeviceAddr, aclDataType::ACL_FLOAT16, &dQIndex);
    CHECK_RET(ret == ACL_SUCCESS, return ret);
    ret = CreateAclTensor(dKIndexHostData, dKIndexShape, &dKIndexDeviceAddr, aclDataType::ACL_FLOAT16, &dKIndex);
    CHECK_RET(ret == ACL_SUCCESS, return ret);
    ret = CreateAclTensor(dWeightHostData, dWeightShape, &dWeightDeviceAddr, aclDataType::ACL_FLOAT16, &dWeight);
    CHECK_RET(ret == ACL_SUCCESS, return ret);
    ret = CreateAclTensor(lossHostData, lossShape, &lossDeviceAddr, aclDataType::ACL_FLOAT, &loss);
    CHECK_RET(ret == ACL_SUCCESS, return ret);
    std::vector<int64_t> acSeqQLenOp = {t1};
    std::vector<int64_t> acSeqKvLenOp = {t2};
    aclIntArray* acSeqQLen = aclCreateIntArray(acSeqQLenOp.data(), acSeqQLenOp.size());
    aclIntArray* acSeqKvLen = aclCreateIntArray(acSeqKvLenOp.data(), acSeqKvLenOp.size());
    double scaleValue = 1.0 / sqrt(dQuery);
    int64_t preTokens = 2147483647;
    int64_t nextTokens = 2147483647;
    int64_t sparseMode = 3;
    char layOut[5] = {'T', 'N', 'D', 0};

    // 3. Call the CANN operator API (replace with the concrete API name as needed).
    uint64_t workspaceSize = 0;
    aclOpExecutor* executor = nullptr;
    // First stage: aclnnDenseLightningIndexerGradKLLossGetWorkspaceSize
    ret = aclnnDenseLightningIndexerGradKLLossGetWorkspaceSize(
        q, k, qIndex, kIndex, weight, softmaxMax, softmaxSum, softmaxMaxIndex, softmaxSumIndex,
        qRope, kRope, acSeqQLen, acSeqKvLen, scaleValue, layOut, sparseMode, preTokens, nextTokens,
        dQIndex, dKIndex, dWeight, loss, &workspaceSize, &executor);
    CHECK_RET(ret == ACL_SUCCESS,
              LOG_PRINT("aclnnDenseLightningIndexerGradKLLossGetWorkspaceSize failed. ERROR: %d\n", ret);
              return ret);
    // Allocate device memory for the workspace size reported by the first stage
    void* workspaceAddr = nullptr;
    if (workspaceSize > 0) {
        ret = aclrtMalloc(&workspaceAddr, workspaceSize, ACL_MEM_MALLOC_HUGE_FIRST);
        CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("allocate workspace failed. ERROR: %d\n", ret); return ret);
    }
    // Second stage: aclnnDenseLightningIndexerGradKLLoss
    ret = aclnnDenseLightningIndexerGradKLLoss(workspaceAddr, workspaceSize, executor, stream);
    CHECK_RET(ret == ACL_SUCCESS,
              LOG_PRINT("aclnnDenseLightningIndexerGradKLLoss failed. ERROR: %d\n", ret); return ret);

    // 4. Boilerplate: wait for the task to finish
    ret = aclrtSynchronizeStream(stream);
    CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtSynchronizeStream failed. ERROR: %d\n", ret); return ret);

    // 5. Copy the outputs from device memory back to the host (adapt to the concrete API).
    PrintOutResult(dQIndexShape, &dQIndexDeviceAddr);
    PrintOutResult(dKIndexShape, &dKIndexDeviceAddr);
    PrintOutResult(dWeightShape, &dWeightDeviceAddr);
    PrintFloatResult(lossShape, &lossDeviceAddr);

    // 6. Release the aclTensor and aclIntArray objects (adapt to the concrete API).
    aclDestroyTensor(q);
    aclDestroyTensor(k);
    aclDestroyTensor(qIndex);
    aclDestroyTensor(kIndex);
    aclDestroyTensor(qRope);
    aclDestroyTensor(kRope);
    aclDestroyTensor(weight);
    aclDestroyTensor(softmaxMax);
    aclDestroyTensor(softmaxSum);
    aclDestroyTensor(softmaxMaxIndex);
    aclDestroyTensor(softmaxSumIndex);
    aclDestroyTensor(dQIndex);
    aclDestroyTensor(dKIndex);
    aclDestroyTensor(dWeight);
    aclDestroyTensor(loss);
    aclDestroyIntArray(acSeqQLen);
    aclDestroyIntArray(acSeqKvLen);

    // 7. Release device resources.
    aclrtFree(qDeviceAddr);
    aclrtFree(kDeviceAddr);
    aclrtFree(qIndexDeviceAddr);
    aclrtFree(kIndexDeviceAddr);
    aclrtFree(qRopeDeviceAddr);
    aclrtFree(kRopeDeviceAddr);
    aclrtFree(weightDeviceAddr);
    aclrtFree(softmaxMaxDeviceAddr);
    aclrtFree(softmaxSumDeviceAddr);
    aclrtFree(softmaxMaxIndexDeviceAddr);
    aclrtFree(softmaxSumIndexDeviceAddr);
    aclrtFree(dQIndexDeviceAddr);
    aclrtFree(dKIndexDeviceAddr);
    aclrtFree(dWeightDeviceAddr);
    aclrtFree(lossDeviceAddr);
    if (workspaceSize > 0) {
        aclrtFree(workspaceAddr);
    }
    aclrtDestroyStream(stream);
    aclrtDestroyContext(context);
    aclrtResetDevice(deviceId);
    aclFinalize();
    return 0;
}
```