#include <CL/sycl.hpp>
#include <algorithm>
#include <array>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <iomanip>
#include <iostream>
#include <string>
#include <utility>
#include <vector>

using namespace cl::sycl;

// Gray-Scott reaction parameters: kill rate (k), feed rate (F) and the explicit
// Euler time step used by the device kernel.
constexpr float KILL_RATE = 0.062f;
constexpr float FEED_RATE = 0.03f;
constexpr float DT = 1.0f;

// Diffusion coefficients multiplying the stencil sums of u and v respectively.
constexpr float DIFFUSION_RATE_U = 0.1f;
constexpr float DIFFUSION_RATE_V = 0.05f;

//constexpr std::size_t nb_rows { 100ul }; //{ 1080ul };
//constexpr std::size_t nb_cols { 200ul }; //{ 1920ul };
//constexpr std::size_t nb_images { 5ul };
//constexpr std::size_t nb_iterations { 1000ul }; // 680000

/// Enqueue one explicit-Euler Gray-Scott step on `q`.
///
/// Reads the current state from (iub, ivb) and writes the next state into
/// (oub, ovb). All four buffers are indexed as a padded grid of
/// (nb_rows + 2) x (nb_cols + 2): the kernel updates only the interior cells
/// [1, nb_rows] x [1, nb_cols] and reads the one-cell border as a halo.
///
/// No explicit wait is issued. Ordering between successive calls comes from
/// the SYCL buffer/accessor dependency graph, and a later host_accessor
/// synchronises before the host reads any result.
///
/// @param q        queue the command group is submitted to
/// @param iub,ivb  input u / v buffers (read-only in the kernel)
/// @param oub,ovb  output u / v buffers (interior written by the kernel)
/// @param nb_rows  number of interior rows (excluding the halo)
/// @param nb_cols  number of interior columns (excluding the halo)
void submit( queue & q,
    buffer<float,2> & iub, buffer<float,2> & ivb,
    buffer<float,2> & oub, buffer<float,2> & ovb,
    std::size_t nb_rows, std::size_t nb_cols ) {
    q.submit([&](handler& h) {

        // Output accessors are write_only without `no_init`, so cells the
        // kernel never writes (the border) keep their previous contents.
        accessor iua(iub, h, read_only);
        accessor iva(ivb, h, read_only);
        accessor oua(oub, h, write_only);
        accessor ova(ovb, h, write_only);

        // One work-item per interior cell; (row, col) is offset by +1 into the padded grid.
        h.parallel_for(range<2>{nb_rows,nb_cols}, [=](item<2> it) {
            const std::size_t row = it.get_id(0);
            const std::size_t col = it.get_id(1);

            const float u = iua[id<2>{row+1,col+1}];
            const float v = iva[id<2>{row+1,col+1}];
            const float uvv = u*v*v;

            // Sum of (neighbour - centre) over the 3x3 window whose top-left is
            // (row, col). The centre term is zero, so this is an unweighted
            // 8-neighbour Laplacian-style sum.
            float full_u = 0.0f;
            float full_v = 0.0f;
            for (std::size_t k = 0; k < 3; ++k) {
                for (std::size_t l = 0; l < 3; ++l) {
                    full_u += iua[id<2>{row+k,col+l}] - u;
                    full_v += iva[id<2>{row+k,col+l}] - v;
                }
            }

            const float du = DIFFUSION_RATE_U*full_u - uvv + FEED_RATE*(1.0f - u);
            const float dv = DIFFUSION_RATE_V*full_v + uvv - (FEED_RATE + KILL_RATE)*v;

            oua[id<2>{row+1,col+1}] = u + du*DT;
            ova[id<2>{row+1,col+1}] = v + dv*DT;
        });
    });
}

int main( int argc, char * argv[] ) {

    // runtime parameters
    assert(argc=5) ;
    std::size_t nb_rows {std::stoul(argv[1])} ;
    std::size_t nb_cols {std::stoul(argv[2])} ;
    std::size_t nb_images {std::stoul(argv[3])} ;
    std::size_t nb_iterations {std::stoul(argv[4])} ;
    assert(nb_iterations % 2 == 0); // nb_iterations must be even

    try {

        // Loop through available platforms and devices
        for (auto const& this_platform : platform::get_platforms() ) {
            std::cout << "Found platform: "
                << this_platform.get_info<info::platform::name>() << std::endl;
            for (auto const& this_device : this_platform.get_devices() ) {
                std::cout << "  Device: "
                    << this_device.get_info<info::device::name>() << std::endl;
        // Create SYCL queue
        queue q;
        // Running platform and device
        std::cout << "Running on platform: "
            << q.get_device().get_platform().get_info<info::platform::name>() << std::endl;
        std::cout << "  Device: "
            << q.get_device().get_info<info::device::name>() << std::endl;
        std::cout << std::endl;
        // Initialize input array
        const std::size_t padded_nb_rows { nb_rows+2 };
        const std::size_t padded_nb_cols { nb_cols+2 };
        const std::size_t size { padded_nb_rows*padded_nb_cols };
        std::vector<float> u1(size);
        std::vector<float> v1(size);
        std::vector<float> u2(size);
        std::vector<float> v2(size);
        for (int i = 0; i < padded_nb_rows; i++) {
            for (int j = 0; j < padded_nb_cols; j++) {
                u1[i*padded_nb_cols+j] = 1.f;
                v1[i*padded_nb_cols+j] = 0.f;
                u2[i*padded_nb_cols+j] = 1.f;
                v2[i*padded_nb_cols+j] = 0.f;
        const std::size_t v_row_begin { (7ul*padded_nb_rows+8ul)/16ul };
        const std::size_t v_row_end { (9ul*padded_nb_rows+8ul)/16ul };
        const std::size_t v_col_begin { (7ul*padded_nb_cols+8ul)/16ul };
        const std::size_t v_col_end { (9ul*padded_nb_cols+8ul)/16ul };
        std::cout << "v_row_begin: " << v_row_begin << std::endl;
        std::cout << "v_row_end:   " << v_row_end   << std::endl;
        std::cout << "v_col_begin: " << v_col_begin << std::endl;
        std::cout << "v_col_end:   " << v_col_end   << std::endl;
        std::cout << std::endl;
        for (int i = v_row_begin; i < v_row_end; i++) {
            for (int j = v_col_begin; j < v_col_end; j++) {
                u1[i*padded_nb_cols+j] = 0.f;
                v1[i*padded_nb_cols+j] = 1.f;
        // Create buffers
        buffer<float,2> u1b { u1.data(),range<2>{padded_nb_rows,padded_nb_cols} };
        buffer<float,2> v1b { v1.data(),range<2>{padded_nb_rows,padded_nb_cols} };
        buffer<float,2> u2b { u2.data(),range<2>{padded_nb_rows,padded_nb_cols} };
        buffer<float,2> v2b { v2.data(),range<2>{padded_nb_rows,padded_nb_cols} };
        // iterations
        for ( std::size_t image = 0 ; image < nb_images ; ++image ) {
            for ( std::size_t iter = 0 ; iter < nb_iterations ; iter += 2 ) {
                submit( q, u1b, v1b, u2b, v2b, nb_rows, nb_cols );
                submit( q, u2b, v2b, u1b, v1b, nb_rows, nb_cols );
        // Print some result
        const std::size_t row_center { padded_nb_rows/2ul };
        const std::size_t col_center { padded_nb_cols/2ul };
        host_accessor u1ha{u1b, read_only};
        std::cout<<std::fixed<<std::setprecision(2) ;
        for (std::size_t i = (row_center-5ul) ; i < (row_center+5ul); i++) {
            for (std::size_t j = (col_center-5ul); j < (col_center+5ul); j++) {
                std::cout << u1ha[id<2>{i,j}] << " ";
            std::cout << "\n";
        std::cout << std::endl;
        host_accessor v1ha{v1b, read_only};
        std::cout<<std::fixed<<std::setprecision(2) ;
        for (std::size_t i = (row_center-5ul) ; i < (row_center+5ul); i++) {
            for (std::size_t j = (col_center-5ul); j < (col_center+5ul); j++) {
                std::cout << v1ha[id<2>{i,j}] << " ";
            std::cout << "\n";
        std::cout << std::endl;
    catch (sycl::exception & e) {
      std::cout << e.what() << std::endl;
      std::cout << e.category() << std::endl;
      std::cout << e.code() << std::endl;
    catch (std::exception & e) {
      std::cout << e.what() << std::endl;
    catch (const char * e) {
      std::cout << e << std::endl;

    return 0;