leon 5 luni în urmă
părinte
comite
a116750e37
1 a modificat fișierele cu 5 adăugiri și 4 ștergeri
  1. 5 4
      src/resnet.cu

+ 5 - 4
src/resnet.cu

@@ -1,5 +1,6 @@
 #include "infer.hpp"
 #include "resnet.hpp"
+#include <cfloat>
 
 namespace resnet
 {
@@ -227,7 +228,7 @@ static __global__ void softmax(float *predict, int length)
     // 1. 找到最大值,存储在共享内存中
     float max_val = -FLT_MAX;
     for (int i = tid; i < length; i += blockDim.x) {
-        max_val = max(max_val, data[i]);
+        max_val = max(max_val, predict[i]);
     }
     shared_data[tid] = max_val;
     __syncthreads();
@@ -244,8 +245,8 @@ static __global__ void softmax(float *predict, int length)
     // 2. 计算指数并求和
     float sum_exp = 0.0f;
     for (int i = tid; i < length; i += blockDim.x) {
-        data[i] = expf(data[i] - max_val);
-        sum_exp += data[i];
+        predict[i] = expf(data[i] - max_val);
+        sum_exp += predict[i];
     }
     shared_data[tid] = sum_exp;
     __syncthreads();
@@ -261,7 +262,7 @@ static __global__ void softmax(float *predict, int length)
 
     // 3. 每个元素除以总和,得到 softmax 值
     for (int i = tid; i < length; i += blockDim.x) {
-        data[i] /= total_sum;
+        predict[i] /= total_sum;
     }
 }