#include <torch/extension.h>
#include <cuda_runtime.h>

// Upper bounds on the hidden-layer widths. Each thread stages its activations
// in fixed-size local arrays, so the host wrapper must reject larger layers.
constexpr int MAX_HID1 = 128;
constexpr int MAX_HID2 = 64;

// Fused 3-layer MLP forward pass: one thread computes one sample end to end,
// keeping intermediate activations in local arrays instead of global memory.
// Weights follow the PyTorch nn.Linear layout, W of shape (out_features, in_features),
// so y_i = b_i + sum_j x_j * W[i * in_features + j].
__global__ void fused_forward(
    const float* __restrict__ input,
    float* __restrict__ output,
    const float* __restrict__ W1, const float* __restrict__ b1,
    const float* __restrict__ W2, const float* __restrict__ b2,
    const float* __restrict__ W3, const float* __restrict__ b3,
    int batch_size, int in_dim, int hid1, int hid2, int out_dim
) {
    int sample_idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (sample_idx >= batch_size) return;

    const float* x = input + sample_idx * in_dim;
    float* out = output + sample_idx * out_dim;

    // Layer 1: linear + ReLU. hid1 <= MAX_HID1 is enforced by the host wrapper.
    float hidden1[MAX_HID1];
    for (int i = 0; i < hid1; ++i) {
        float sum = b1[i];
        for (int j = 0; j < in_dim; ++j) {
            sum += x[j] * W1[i * in_dim + j];
        }
        hidden1[i] = fmaxf(sum, 0.0f);  // ReLU
    }

    // Layer 2: linear + ReLU. hid2 <= MAX_HID2 is enforced by the host wrapper.
    float hidden2[MAX_HID2];
    for (int i = 0; i < hid2; ++i) {
        float sum = b2[i];
        for (int j = 0; j < hid1; ++j) {
            sum += hidden1[j] * W2[i * hid1 + j];
        }
        hidden2[i] = fmaxf(sum, 0.0f);  // ReLU
    }

    // Output layer: linear, no activation.
    for (int i = 0; i < out_dim; ++i) {
        float sum = b3[i];
        for (int j = 0; j < hid2; ++j) {
            sum += hidden2[j] * W3[i * hid2 + j];
        }
        out[i] = sum;
    }
}
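
// A hedged reference sketch, not part of the original extension: the same
// computation expressed with ATen ops, useful for validating the fused kernel
// numerically (e.g. with torch::allclose). `reference_forward` is an added
// helper name introduced here for illustration, not an existing API.
torch::Tensor reference_forward(
    torch::Tensor input,
    torch::Tensor W1, torch::Tensor b1,
    torch::Tensor W2, torch::Tensor b2,
    torch::Tensor W3, torch::Tensor b3
) {
    auto h1 = torch::relu(torch::linear(input, W1, b1));  // layer 1: linear + ReLU
    auto h2 = torch::relu(torch::linear(h1, W2, b2));     // layer 2: linear + ReLU
    return torch::linear(h2, W3, b3);                     // output layer: linear only
}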
torch::Tensor fused_forward_cuda(
    torch::Tensor input,
    torch::Tensor W1, torch::Tensor b1,
    torch::Tensor W2, torch::Tensor b2,
    torch::Tensor W3, torch::Tensor b3
) {
    TORCH_CHECK(input.is_cuda(), "input must be a CUDA tensor");
    TORCH_CHECK(input.scalar_type() == torch::kFloat32, "only float32 is supported");
    TORCH_CHECK(input.is_contiguous() && W1.is_contiguous() && b1.is_contiguous() &&
                W2.is_contiguous() && b2.is_contiguous() &&
                W3.is_contiguous() && b3.is_contiguous(),
                "all tensors must be contiguous");

    int batch_size = input.size(0);
    int in_dim = W1.size(1);   // nn.Linear layout: (out_features, in_features)
    int hid1 = W1.size(0);
    int hid2 = W2.size(0);
    int out_dim = W3.size(0);

    // The kernel stages activations in fixed-size local arrays.
    TORCH_CHECK(hid1 <= MAX_HID1, "hid1 must be <= ", MAX_HID1);
    TORCH_CHECK(hid2 <= MAX_HID2, "hid2 must be <= ", MAX_HID2);

    // Every element is written by the kernel, so uninitialized storage suffices.
    torch::Tensor output = torch::empty({batch_size, out_dim}, input.options());

    // One thread per sample.
    int threads = 256;
    int blocks = (batch_size + threads - 1) / threads;

    fused_forward<<<blocks, threads>>>(
        input.data_ptr<float>(), output.data_ptr<float>(),
        W1.data_ptr<float>(), b1.data_ptr<float>(),
        W2.data_ptr<float>(), b2.data_ptr<float>(),
        W3.data_ptr<float>(), b3.data_ptr<float>(),
        batch_size, in_dim, hid1, hid2, out_dim
    );

    // Surface launch errors instead of failing silently.
    cudaError_t err = cudaGetLastError();
    TORCH_CHECK(err == cudaSuccess, "fused_forward launch failed: ", cudaGetErrorString(err));

    return output;
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("fused_forward", &fused_forward_cuda, "Fused forward pass (CUDA)");
}
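
// Example usage from Python via JIT compilation, as a sketch; the module and
// file names "fused_mlp" / "fused_forward.cu" are placeholders, not fixed by
// this file:
//
//   from torch.utils.cpp_extension import load
//   ext = load(name="fused_mlp", sources=["fused_forward.cu"])
//   out = ext.fused_forward(x, w1, b1, w2, b2, w3, b3)
//
// where x is a contiguous float32 CUDA tensor of shape (batch, in_dim) and the
// weights/biases follow the nn.Linear layout described above.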