/*****************************************************************************
 * $CAMITK_LICENCE_BEGIN$
 *
 * CamiTK - Computer Assisted Medical Intervention ToolKit
 * (c) 2001-2025 Univ. Grenoble Alpes, CNRS, Grenoble INP - UGA, TIMC, 38000 Grenoble, France
 *
 * Visit http://camitk.imag.fr for more information
 *
 * This file is part of CamiTK.
 *
 * CamiTK is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Lesser General Public License version 3
 * only, as published by the Free Software Foundation.
 *
 * CamiTK is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU Lesser General Public License version 3 for more details.
 *
 * You should have received a copy of the GNU Lesser General Public License
 * version 3 along with CamiTK.  If not, see <http://www.gnu.org/licenses/>.
 *
 * $CAMITK_LICENCE_END$
 ****************************************************************************/

#include "transformation_manager_bindings.h"
#include "core_utils.h"
#include "qt_bindings.h"
#include "numpy_utils.h"
#include "docstrings.h"

#include <TransformationManager.h>


namespace pybind11::detail {

// --------------- QVector of FrameOfReference ---------------
// Conversion from Python PyObject to C++ QVector<camitk::FrameOfReference*>
bool type_caster<QVector<camitk::FrameOfReference*>>::load(handle src, bool /* not used : indicates whether implicit conversions should be applied */) {
    // Accept list or tuple
    if (!py::isinstance<py::sequence>(src)) {
        return false;
    }

    py::sequence seq = reinterpret_borrow<py::sequence>(src);
    value.clear();
    for (auto f : seq) {
        if (f.is_none()) {
            value.append(nullptr);
        }
        else {
            camitk::FrameOfReference* ptr = nullptr;

            // Try casting to camitk::Component
            try {
                ptr = f.cast<camitk::FrameOfReference*>();
            }
            catch (const py::cast_error&) {
                return false;  // Invalid item in list
            }

            value.append(ptr);
        }
    }

    return true;
}

// Conversion from C++ QVector<camitk::FrameOfReference*> to Python PyObject
handle type_caster<QVector<camitk::FrameOfReference*>>::cast(const QVector<camitk::FrameOfReference*>& list, return_value_policy /* policy */, handle /* parent */) {
    py::list pyList;

    for (camitk::FrameOfReference* f : list) {
        pyList.append(py::cast(f, py::return_value_policy::reference));  // <- this will do the dynamic cast
    }

    return pyList.release();
}

// --------------- QVector of Transformation ---------------
// Conversion from Python PyObject to C++ QVector<camitk::Transformation*>
bool type_caster<QVector<camitk::Transformation*>>::load(handle src, bool /* not used : indicates whether implicit conversions should be applied */) {
    // Accept list or tuple
    if (!py::isinstance<py::sequence>(src)) {
        return false;
    }

    py::sequence seq = reinterpret_borrow<py::sequence>(src);
    value.clear();
    for (auto f : seq) {
        if (f.is_none()) {
            value.append(nullptr);
        }
        else {
            camitk::Transformation* ptr = nullptr;

            // Try casting to camitk::Component
            try {
                ptr = f.cast<camitk::Transformation*>();
            }
            catch (const py::cast_error&) {
                return false;  // Invalid item in list
            }

            value.append(ptr);
        }
    }

    return true;
}

// Conversion from C++ QVector<camitk::Transformation*> to Python PyObject
handle type_caster<QVector<camitk::Transformation*>>::cast(const QVector<camitk::Transformation*>& list, return_value_policy /* policy */, handle /* parent */) {
    py::list pyList;

    for (camitk::Transformation* f : list) {
        pyList.append(py::cast(f, py::return_value_policy::reference));  // <- this will do the dynamic cast
    }

    return pyList.release();
}

// --------------- vtkSmartPointer<vtkTransform> ---------------
// Conversion from Python PyObject to C++ vtkSmartPointer<vtkTransform>
bool type_caster<vtkSmartPointer<vtkTransform>>::load(handle src, bool) {
    try {
        auto array = py::array_t<double>::ensure(src);
        if (!array || array.ndim() != 2 || array.shape(0) != 4 || array.shape(1) != 4) {
            return false;
        }

        auto* data = static_cast<double*>(array.request().ptr);

        value = vtkSmartPointer<vtkTransform>::New();
        value->SetMatrix(data);
        return true;
    }
    catch (...) {
        return false;
    }
}

// Conversion from C++ to Python PyObject
handle type_caster<vtkSmartPointer<vtkTransform>>::cast(vtkSmartPointer<vtkTransform> src,
        return_value_policy /* policy */,
handle /* parent */) {
    if (!src) {
        return py::none().release();
    }

    vtkMatrix4x4* m = src->GetMatrix();
    std::array<double, 16> flat{};
    for (int i = 0; i < 4; ++i)
        for (int j = 0; j < 4; ++j) {
            flat[i * 4 + j] = m->GetElement(i, j);
        }

    return py::array_t<double>({4, 4}, flat.data()).release();
}

} // namespace pybind11::detail


void add_transformation_manager_bindings(py::module_& m) {

    // pybind11 generates each function signature (shown by help() and in docstrings) when
    // .def()/.def_static() is called. Any type not registered yet falls back to its raw C++
    // type name. FrameOfReference and Transformation are therefore registered first, so that
    // the TransformationManager signatures below refer to the Python classes.

    // --------------- FrameOfReference ---------------
    py::class_<camitk::FrameOfReference> frameOfReference(m, "FrameOfReference", DOC(camitk_FrameOfReference));

    frameOfReference.def("getName", &camitk::FrameOfReference::getName,
                         DOC(camitk_FrameOfReference_getName));

    frameOfReference.def("getUuid", &camitk::FrameOfReference::getUuid,
                         DOC(camitk_FrameOfReference_getUuid));

    // --------------- Transformation ---------------
    py::class_<camitk::Transformation> transformation(m, "Transformation", DOC(camitk_Transformation));

    transformation.def("getName", &camitk::Transformation::getName,
                       DOC(camitk_Transformation_getName));

    transformation.def("getTransform", &camitk::Transformation::getTransform,
                       py::return_value_policy::reference,
                       DOC(camitk_Transformation_getTransform));

    transformation.def("getFrom", &camitk::Transformation::getFrom, py::return_value_policy::reference,
                       DOC(camitk_Transformation_getFrom));

    transformation.def("getTo", &camitk::Transformation::getTo, py::return_value_policy::reference,
                       DOC(camitk_Transformation_getTo));

    // --------------- TransformationManager ---------------
    // Only static methods are bound. Frames and transformations are returned with
    // return_value_policy::reference, so Python does not take ownership of them.

    py::class_<camitk::TransformationManager> transformationManager(m, "TransformationManager",
            DOC(camitk_TransformationManager));

    transformationManager.def_static("getWorldFrame", &camitk::TransformationManager::getWorldFrame,
                                     py::return_value_policy::reference,
                                     DOC(camitk_TransformationManager_getWorldFrame));

    transformationManager.def_static("getFramesOfReference", &camitk::TransformationManager::getFramesOfReference,
                                     py::return_value_policy::reference,
                                     DOC(camitk_TransformationManager_getFramesOfReference));

    transformationManager.def_static("getTransformations", &camitk::TransformationManager::getTransformations,
                                     py::return_value_policy::reference,
                                     DOC(camitk_TransformationManager_getTransformations));

    transformationManager.def_static("addFrameOfReference", [](QString name, QString description) {
        std::shared_ptr<camitk::FrameOfReference> newFrame = camitk::TransformationManager::addFrameOfReference(name, description);
        return newFrame.get(); // warning: the new frame has to be used before the TransformationManager does any cleanup
    },
    py::arg("name"),
    py::arg("description") = "",
    py::return_value_policy::reference,
    DOC(camitk_TransformationManager_addFrameOfReference_1));

    transformationManager.def_static("addTransformation", [](const camitk::FrameOfReference * from, const camitk::FrameOfReference * to) {
        std::shared_ptr<camitk::Transformation> newTransformation = camitk::TransformationManager::addTransformation(from, to);
        return newTransformation.get();
    },
    py::return_value_policy::reference,
    DOC(camitk_TransformationManager_addTransformation_2));

    // Sentence appended to the docstrings of the numpy-matrix overloads of updateTransformation.
    // Plain string concatenation is used instead of chained QString::arg(), which could also
    // substitute "%N" markers contained in the generated doc text.
    const std::string numpyMatrixNote = "\nNote that the vtkMatrix must be given as a 4x4 numpy array.";
    const std::string updateFromToNumpyDoc = std::string(DOC(camitk_TransformationManager_updateTransformation_2)) + numpyMatrixNote;
    const std::string updateTrNumpyDoc = std::string(DOC(camitk_TransformationManager_updateTransformation_4)) + numpyMatrixNote;

    // NOTE(review): the vtkSmartPointer<vtkTransform> type caster already loads 4x4 numpy arrays,
    // and these overloads are registered first, so the py::array overloads presumably only
    // receive arrays that caster rejected — confirm this is the intended dispatch.
    transformationManager.def_static("updateTransformation", [](const camitk::FrameOfReference * from, const camitk::FrameOfReference * to, vtkSmartPointer<vtkTransform> vtkTr) {
        camitk::TransformationManager::updateTransformation(from, to, vtkTr);
    },
    DOC(camitk_TransformationManager_updateTransformation_1));

    transformationManager.def_static("updateTransformation", [](const camitk::FrameOfReference * from, const camitk::FrameOfReference * to, py::array numpyMatrix) {
        // Call the real function
        camitk::TransformationManager::updateTransformation(from, to, camitk::numpyToVtkTransform(numpyMatrix));
    },
    updateFromToNumpyDoc.c_str());

    transformationManager.def_static("updateTransformation", [](camitk::Transformation * tr, vtkSmartPointer<vtkTransform> vtkTr) {
        camitk::TransformationManager::updateTransformation(tr, vtkTr);
    },
    DOC(camitk_TransformationManager_updateTransformation_3));

    transformationManager.def_static("updateTransformation", [](camitk::Transformation * tr, py::array numpyMatrix) {
        camitk::TransformationManager::updateTransformation(tr, camitk::numpyToVtkTransform(numpyMatrix));
    },
    updateTrNumpyDoc.c_str());

    transformationManager.def_static("getTransformation", &camitk::TransformationManager::getTransformation,
                                     py::return_value_policy::reference,
                                     DOC(camitk_TransformationManager_getTransformation));

    // The C++ API removes a transformation through its shared_ptr: first retrieve
    // that shared_ptr from the raw pointer known by Python
    transformationManager.def_static("removeTransformation", [](camitk::Transformation * trPtr) {
        std::shared_ptr<camitk::Transformation> trSharePtr = camitk::TransformationManager::getTransformationOwnership(trPtr);
        return camitk::TransformationManager::removeTransformation(trSharePtr);
    },
    DOC(camitk_TransformationManager_removeTransformation_1));

    transformationManager.def_static("removeTransformation", [](const camitk::FrameOfReference * from, const camitk::FrameOfReference * to) {
        return camitk::TransformationManager::removeTransformation(from, to);
    },
    DOC(camitk_TransformationManager_removeTransformation_2));

    transformationManager.def_static("preferredDefaultIdentityToWorldLink", &camitk::TransformationManager::preferredDefaultIdentityToWorldLink,
                                     DOC(camitk_TransformationManager_preferredDefaultIdentityToWorldLink));

}

