import { useBaseXyz } from '@local/webviz/dist/context/hooks/useBaseXyz';
import { SurfaceElementState } from '@local/webviz/dist/types';
import { useCallback } from 'react';

import { GtmMeshDetectorAction } from 'src/apiClients/gtmCompute/gtmComputeApi.types';

import { DefectClasses, generateDefectsSnapshotState } from './defectsFactory';
import { generateHighlightSnapshot, HighlightClasses } from './highlightFactory';
import { defaultGtmMeshDetectionData, DefectData } from './types';

/**
 * Hook exposing callbacks that draw mesh-defect views and single-defect highlight views
 * for one object in the plot.
 *
 * Both callbacks look up the entity stored under `${objectId}-element` and do nothing unless
 * it carries both `triangles` and `vertices`. When it does, the previously plotted views of
 * the same kind are removed first, then new views are generated from `defects` and plotted.
 *
 * @param objectId - Id of the object; its surface entity is read from `${objectId}-element`.
 * @param defects - Detection results keyed by detector action; only entries with a truthy
 * `defects` payload are rendered.
 * @returns `renderDefects` and `renderHighlights`, each memoized on `objectId` and `defects`.
 */
export function useDefectsVisualizationManager({ objectId, defects }: DefectData) {
    const {
        getEntityState,
        setStateFromSnapshot,
        addViewToPlotDirectly,
        getState,
        removeViewsFromPlotDirectly,
    } = useBaseXyz();

    /** Removes from the plot every state entry whose key contains `marker`. */
    function removeViewsMatching(marker: string) {
        // `filter` always yields a fresh array, so no nullish fallback or defensive copy is needed.
        const matchingKeys = Object.keys(getState()).filter((key) => key.includes(marker));
        removeViewsFromPlotDirectly(matchingKeys);
    }

    /** Returns the object's surface entity when it has both `triangles` and `vertices`. */
    function getRenderableSurface() {
        const surface = getEntityState(`${objectId}-element`) as SurfaceElementState;
        // Return type is left inferred so the `in` narrowing is carried to the callers.
        return surface && 'triangles' in surface && 'vertices' in surface ? surface : undefined;
    }

    /** Applies every non-empty snapshot in the map to the state, then adds its view to the plot. */
    function plotSnapshotViews(
        snapshotMap: Record<string, Parameters<typeof setStateFromSnapshot>[0] | undefined>,
    ) {
        Object.keys(snapshotMap).forEach((viewId) => {
            const snapshot = snapshotMap[viewId];
            if (snapshot) {
                // The state must be set before the view is added to the plot.
                setStateFromSnapshot(snapshot, {});
                addViewToPlotDirectly(viewId);
            }
        });
    }

    // NOTE(review): only `objectId` and `defects` are listed as dependencies below, so the
    // webviz functions are captured from the render in which those last changed. This matches
    // the previous behaviour; confirm that `useBaseXyz` returns stable references.

    /**
     * Replaces the current highlight views with the highlight of one defect.
     *
     * @param action - Detector action whose results contain the defect to highlight.
     * @param defectIndex - Index of the defect, forwarded to the highlight factory.
     */
    const renderHighlights = useCallback(
        (action: GtmMeshDetectorAction, defectIndex: number) => {
            const surface = getRenderableSurface();
            if (!surface) {
                return;
            }

            // Existing highlights are cleared even when there is nothing new to draw.
            removeViewsMatching('highlight');

            if (action && defects && action in defects && defects[action]?.defects) {
                const snapshotMap = generateHighlightSnapshot(action as HighlightClasses, {
                    objectId,
                    defectIndex,
                    surface,
                    defects: defects[action]?.defects ?? defaultGtmMeshDetectionData,
                });
                if (snapshotMap) {
                    plotSnapshotViews(snapshotMap);
                }
            }
        },
        [objectId, defects],
    );

    /** Replaces the current defect views with views for every detector action in `defects`. */
    const renderDefects = useCallback(() => {
        const surface = getRenderableSurface();
        if (!surface) {
            return;
        }

        // Existing defect views are cleared even when `defects` is empty or missing.
        removeViewsMatching('defect');

        if (!defects) {
            return;
        }

        // Keys come from `Object.keys(defects)`, so a separate `in` check would be redundant.
        for (const detectorAction of Object.keys(defects) as DefectClasses[]) {
            if (defects[detectorAction]?.defects) {
                const snapshotMap = generateDefectsSnapshotState(detectorAction, {
                    defects: defects[detectorAction]?.defects ?? defaultGtmMeshDetectionData,
                    surface,
                    objectId,
                });
                if (snapshotMap) {
                    plotSnapshotViews(snapshotMap);
                }
            }
        }
    }, [objectId, defects]);

    return {
        renderDefects,
        renderHighlights,
    };
}
