/*
 * Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES
 * SPDX-License-Identifier: BSD-3-Clause
 */

#include "nccl.h"
#include "cudss.h"

extern "C" {

/*
 * Returns the size in bytes of a single element of the given CUDA data type.
 * Only CUDA_R_32F, CUDA_R_64F, CUDA_R_32I and CUDA_R_64I are recognized;
 * every other value yields 0.
 */
static inline size_t cuda_sizeof_type(cudaDataType_t type) {
    size_t elt_bytes = 0; // stays 0 for any type not listed below
    switch (type) {
    case CUDA_R_32F: elt_bytes = sizeof(float);   break;
    case CUDA_R_64F: elt_bytes = sizeof(double);  break;
    case CUDA_R_32I: elt_bytes = sizeof(int);     break;
    case CUDA_R_64I: elt_bytes = sizeof(int64_t); break;
    default: break;
    }
    return elt_bytes;
}

/*
 * Maps a CUDA data type onto the matching NCCL data type.
 * Handles CUDA_R_32F, CUDA_R_64F, CUDA_R_32I and CUDA_R_64I.
 * Any other value terminates the process via exit(1).
 */
static inline ncclDataType_t cuda_to_nccl_type(cudaDataType_t type) {
    if (type == CUDA_R_32F) return ncclFloat32;
    if (type == CUDA_R_64F) return ncclFloat64;
    if (type == CUDA_R_32I) return ncclInt32;
    if (type == CUDA_R_64I) return ncclInt64;
    //TODO: Proper error
    exit(1);
}

/*
 * Maps a cuDSS reduction operation onto the matching NCCL reduction operation.
 * Handles CUDSS_SUM, CUDSS_MAX and CUDSS_MIN.
 * Any other value terminates the process via exit(1).
 */
static inline ncclRedOp_t cudss_to_nccl_op(cudssOpType_t op) {
    if (op == CUDSS_SUM) return ncclSum;
    if (op == CUDSS_MAX) return ncclMax;
    if (op == CUDSS_MIN) return ncclMin;
    //TODO: Proper error
    exit(1);
}

int cudssCommRank(void *comm, int *rank)
{
    return ncclCommUserRank(*((ncclComm_t*)comm), rank);
}

int cudssCommSize(void *comm, int *size)
{
    return ncclCommCount(*((ncclComm_t*)comm), size);
}

int cudssSend(const void *buffer, int count, cudaDataType_t datatype, int dest,
    int tag, void *comm, cudaStream_t stream)
{
    return ncclSend(buffer, count, cuda_to_nccl_type(datatype), dest,
        *((ncclComm_t*)comm), stream);
}

int cudssRecv(void *buffer, int count, cudaDataType_t datatype, int root,
    int tag, void *comm, cudaStream_t stream)
{
    return ncclRecv(buffer, count, cuda_to_nccl_type(datatype), root,
        *((ncclComm_t*)comm), stream);
}

int cudssBcast(void *buffer, int count, cudaDataType_t datatype, int root,
    void *comm, cudaStream_t stream)
{
    return ncclBcast(buffer, count, cuda_to_nccl_type(datatype), root,
        *((ncclComm_t*)comm), stream);
}

int cudssReduce(const void *sendbuf, void *recvbuf, int count,
    cudaDataType_t datatype, cudssOpType_t op, int root, void *comm,
    cudaStream_t stream)
{
    return ncclReduce(sendbuf, recvbuf, count, cuda_to_nccl_type(datatype),
        cudss_to_nccl_op(op), root, *((ncclComm_t*)comm), stream);
}

int cudssAllreduce(const void *sendbuf, void *recvbuf, int count,
    cudaDataType_t datatype, cudssOpType_t op, void *comm, cudaStream_t stream)
{
    return ncclAllReduce(sendbuf, recvbuf, count, cuda_to_nccl_type(datatype),
        cudss_to_nccl_op(op), *((ncclComm_t*)comm), stream);
}

int cudssScatterv(const void *sendbuf, const int *sendcounts,
    const int *displs, cudaDataType_t sendtype, void *recvbuf, int recvcount,
    cudaDataType_t recvtype, int root, void *comm, cudaStream_t stream)
{
    size_t bytes_per_send_elt = cuda_sizeof_type(sendtype);
    ncclComm_t *nccl_comm = (ncclComm_t*)comm;

    int mpi_size = 1, mpi_rank = 0;
    ncclCommUserRank(*nccl_comm, &mpi_rank);
    ncclCommCount   (*nccl_comm, &mpi_size);

    for (int peer = 0; peer < mpi_size; peer++) {
        if (mpi_rank == root) {
            if (peer == root) {
                cudaMemcpyAsync(recvbuf, (char*)sendbuf + displs[peer] * bytes_per_send_elt,
                    sendcounts[peer] * bytes_per_send_elt, cudaMemcpyDeviceToDevice, stream);
            } else {
                ncclSend((char*)sendbuf + displs[peer] * bytes_per_send_elt,
                    sendcounts[peer], cuda_to_nccl_type(sendtype), peer,
                    *nccl_comm, stream);
            }
        } else if (mpi_rank == peer) {
            ncclRecv(recvbuf, recvcount, cuda_to_nccl_type(recvtype), root,
                *nccl_comm, stream);
            break;
        }
    }
    return 0;
}

int cudssCommSplit(const void *comm, int color, int key, void *newcomm)
{
    return ncclCommSplit(*((ncclComm_t*)comm), color, key, (ncclComm_t*)newcomm,
        NULL/*nccl_config*/);
}

int cudssCommFree(void *comm)
{
    return ncclCommDestroy(*((ncclComm_t*)comm));
}

/*
 * Distributed communication service API wrapper binding table (imported by cuDSS).
 * The exposed C symbol must be named "cudssDistributedInterface".
 *
 * The initializer below is positional: each entry is assigned to the next
 * member of cudssDistributedInterface_t in declaration order. That type is
 * declared outside this file (presumably in cudss.h), so check its member
 * order before reordering, adding or removing entries here.
 */
cudssDistributedInterface_t cudssDistributedInterface = {
    cudssCommRank,      // wraps ncclCommUserRank
    cudssCommSize,      // wraps ncclCommCount
    cudssSend,          // wraps ncclSend
    cudssRecv,          // wraps ncclRecv
    cudssBcast,         // wraps ncclBcast
    cudssReduce,        // wraps ncclReduce
    cudssAllreduce,     // wraps ncclAllReduce
    cudssScatterv,      // built from ncclSend / ncclRecv / cudaMemcpyAsync
    cudssCommSplit,     // wraps ncclCommSplit
    cudssCommFree       // wraps ncclCommDestroy
};

} // extern "C"
