#include <stdlib.h>
#include <cstdio>
#include <iostream>
#include <iomanip>
#include "utility.h"
#include "cuda_runtime.h"
using namespace std;

// Element-wise integer vector addition: z[i] = x[i] + y[i] for 0 <= i < N.
//
// Launch: 1-D grid of 1-D blocks. Any grid/block size is valid, because each
// thread starts at its flat global index and advances by the total thread
// count (gridDim.x * blockDim.x) until it passes N (grid-stride loop).
// On every iteration, adjacent threads of a warp touch adjacent elements, so
// global loads/stores coalesce. No shared memory is used.
//
// x, y : device pointers to at least N ints (only read)
// z    : device pointer to at least N ints (written)
// N    : number of elements; N <= 0 makes the kernel a no-op
__global__ void vecAdd(int *x, int *y, int *z, int N) {
    if (N <= 0)
        return;

    // size_t indices so blockIdx.x * blockDim.x and i + stride cannot
    // overflow int, even for very large grids.
    const size_t n = (size_t)N;
    const size_t stride = (size_t)gridDim.x * blockDim.x;

    for (size_t i = (size_t)blockIdx.x * blockDim.x + threadIdx.x; i < n; i += stride) {
        z[i] = x[i] + y[i];
    }
}

// Report a failed CUDA runtime call with its source location, then exit.
static void cudaCheck(cudaError_t err, const char *call, const char *file, int line) {
    if (err != cudaSuccess) {
        fprintf(stderr, "CUDA error at %s:%d: %s failed: %s\n",
                file, line, call, cudaGetErrorString(err));
        exit(EXIT_FAILURE);
    }
}

// Wrap CUDA runtime calls so a failure is reported instead of silently ignored.
#define CUDA_CHECK(call) cudaCheck((call), #call, __FILE__, __LINE__)

// Benchmark driver: adds two N-element int vectors on the GPU, prints the
// H2D / compute / D2H timings, and verifies the result on the host.
//
// Usage: <prog> N   (N must be a positive integer)
// Returns EXIT_SUCCESS when every z[i] == x[i] + y[i], and EXIT_FAILURE on
// bad arguments, host allocation failure, or any mismatch. CUDA errors
// exit through cudaCheck.
int main(int argc, const char **argv) {
    if (argc < 2) {
        fprintf(stderr, "Usage: %s N\n", argc > 0 ? argv[0] : "vecadd");
        return EXIT_FAILURE;
    }
    // atoi yields 0 for non-numeric input, so the N <= 0 check also rejects it.
    int N = atoi(argv[1]);
    if (N <= 0) {
        fprintf(stderr, "N must be a positive integer (got \"%s\")\n", argv[1]);
        return EXIT_FAILURE;
    }
    const size_t bytes = (size_t)N * sizeof(int);

    // Host Memory
    int *x = (int *)malloc(bytes);
    int *y = (int *)malloc(bytes);
    int *z = (int *)malloc(bytes);
    if (x == NULL || y == NULL || z == NULL) {
        fprintf(stderr, "Failed to allocate %zu bytes of host memory per vector\n", bytes);
        free(x); free(y); free(z);
        return EXIT_FAILURE;
    }

    // init_mat is declared in utility.h; the last argument presumably selects
    // the fill pattern — TODO confirm against utility.h.
    init_mat(x, 1, N, 0);
    init_mat(y, 1, N, 1);
    init_mat(z, 1, N, -1);

    set_clock();

    // GPU Memory. NOTE: the "H2D" time below also includes these allocations
    // (and, if this is the first CUDA call, likely context creation).
    int *x_d = NULL, *y_d = NULL, *z_d = NULL;
    CUDA_CHECK(cudaMalloc((void**) &x_d, bytes));
    CUDA_CHECK(cudaMalloc((void**) &y_d, bytes));
    CUDA_CHECK(cudaMalloc((void**) &z_d, bytes));

    CUDA_CHECK(cudaMemcpy(x_d, x, bytes, cudaMemcpyHostToDevice));
    CUDA_CHECK(cudaMemcpy(y_d, y, bytes, cudaMemcpyHostToDevice));

    double h2dtime = elapsed_time();

    // Compute. Size the grid to the data (ceil(N / threadsPerBlock)) instead
    // of always launching 4096 x 1024 threads. Cap the grid at maxBlocks; the
    // kernel covers all N elements for any grid size.
    const int threadsPerBlock = 256;
    const int maxBlocks = 4096;
    const int neededBlocks = (N - 1) / threadsPerBlock + 1;  // ceil-div, no int overflow
    const int blocks = neededBlocks < maxBlocks ? neededBlocks : maxBlocks;
    vecAdd<<<blocks, threadsPerBlock>>>(x_d, y_d, z_d, N);
    CUDA_CHECK(cudaGetLastError());       // invalid launch configuration
    CUDA_CHECK(cudaDeviceSynchronize());  // faults raised while the kernel ran

    double exectime = elapsed_time() - h2dtime;

    // Copy result back (cudaMemcpy blocks, so z is ready when it returns).
    CUDA_CHECK(cudaMemcpy(z, z_d, bytes, cudaMemcpyDeviceToHost));

    double d2htime = elapsed_time() - (h2dtime + exectime);
    printf("Time taken: %.4f secs\n", (h2dtime + exectime + d2htime));
    printf("Time taken for H2D: %.4f secs\n", h2dtime);
    printf("Time taken for compute: %.4f secs\n", exectime);
    printf("Time taken for D2H: %.4f secs\n", d2htime);

    // Verify on the host and remember whether anything mismatched.
    long long mismatches = 0;
    for(int i = 0; i < N; i++) {
        if(z[i] != x[i] + y[i]) {
            printf("Error in code! For i = %d, %d != %d + %d\n", i, z[i], x[i], y[i]);
            ++mismatches;
        }
    }

    free(x); free(y); free(z);
    CUDA_CHECK(cudaFree(x_d));
    CUDA_CHECK(cudaFree(y_d));
    CUDA_CHECK(cudaFree(z_d));

    return mismatches == 0 ? EXIT_SUCCESS : EXIT_FAILURE;
}