// SPDX-License-Identifier: MIT
// Copyright contributors to the kepler.gl project

import {BrushingExtension} from '@deck.gl/extensions';
import {ScatterplotLayer} from '@deck.gl/layers';
import {Vector as ArrowVector} from 'apache-arrow';

import {ArrowDataContainer, buildOneDayFilter, findDefaultColorField} from '@kepler.gl/utils';
import {default as KeplerTable} from '@kepler.gl/table';
import {LAYER_VIS_CONFIGS, CHANNEL_SCALES, ColorRange, PROPERTY_GROUPS} from '@kepler.gl/constants';
import {
  Field,
  Merge,
  RGBColor,
  VisConfigBoolean,
  VisConfigColorRange,
  VisConfigColorSelect,
  VisConfigNumber,
  VisConfigRange
} from '@kepler.gl/types';

import Layer, {
  LayerBaseConfig,
  LayerBaseConfigPartial,
  LayerColorConfig,
  LayerColumn,
  LayerSizeConfig,
  LayerStrokeColorConfig
} from '../base-layer';
import PointLayerIcon from '../point-layer/point-layer-icon';
import {formatTextLabelData} from '../layer-text-label';

/** Constant `getRadius` value given to the highlighted "last point" layer. */
const LAST_POINT_RADIUS = 2;
/** Constant `getLineWidth` value given to the highlighted "last point" layer. */
const LAST_POINT_STROKE = 5;

/**
 * Setting descriptors for each visual configuration entry of the layer, keyed
 * by setting name. This is the declared type of the layer's
 * `visConfigSettings` member.
 */
export type CustomScatterplotLayerVisConfigSettings = {
  radius: VisConfigNumber;
  fixedRadius: VisConfigBoolean;
  opacity: VisConfigNumber;
  outline: VisConfigBoolean;
  thickness: VisConfigNumber;
  strokeColor: VisConfigColorSelect;
  colorRange: VisConfigColorRange;
  strokeColorRange: VisConfigColorRange;
  radiusRange: VisConfigRange;
  filled: VisConfigBoolean;
  individualColor: VisConfigBoolean;
};

/**
 * Dataset columns the layer is bound to.
 *
 * `geometry` is the column the position accessor reads coordinates from.
 * `turtle_idx` and `date` are required columns that this file does not read
 * directly.
 */
export type CustomScatterplotLayerColumnsConfig = {
  geometry: LayerColumn;
  turtle_idx: LayerColumn;
  date: LayerColumn;
};

/**
 * Current values of the layer's visual configuration (the value-side
 * counterpart of `CustomScatterplotLayerVisConfigSettings`).
 */
export type CustomScatterplotLayerVisConfig = {
  radius: number;
  fixedRadius: boolean;
  opacity: number;
  outline: boolean;
  thickness: number;
  strokeColor: RGBColor;
  colorRange: ColorRange;
  strokeColorRange: ColorRange;
  radiusRange: [number, number];
  filled: boolean;
  individualColor: boolean;
};
/** Combined configuration of the fill color, size and stroke color channels. */
export type CustomScatterplotLayerVisualChannelConfig = LayerColorConfig &
  LayerSizeConfig &
  LayerStrokeColorConfig;
/**
 * Full layer configuration: the base layer config with `columns` and
 * `visConfig` replaced by this layer's own types, plus the visual channel
 * configuration.
 */
export type CustomScatterplotLayerConfig = Merge<
  LayerBaseConfig,
  {columns: CustomScatterplotLayerColumnsConfig; visConfig: CustomScatterplotLayerVisConfig}
> &
  CustomScatterplotLayerVisualChannelConfig;

/**
 * One renderable point produced by `calculateDataAttribute`.
 *
 * `position` is the value returned by the position accessor for the row and
 * `index` is the row index taken from the dataset's `filteredIndex`.
 */
export type CustomScatterplotLayerData = {
  position: number[];
  index: number;
};

/**
 * Builds the position accessor used by the layer.
 *
 * The accessor is curried: it first receives the layer's column config, then
 * the data container, and finally a datum carrying the row `index`.
 *
 * @param columns - Layer columns; only `geometry.fieldIdx` is read.
 * @returns A function that, given a data container, yields an accessor
 *   returning the first two components of the row's geometry value. When the
 *   geometry cell is `null`/`undefined` the accessor returns `[NaN, NaN]`
 *   instead of throwing, so callers that filter on `Number.isFinite` can skip
 *   the row.
 */
export const customScatterplotPosAccessor =
  ({geometry}: CustomScatterplotLayerColumnsConfig) =>
  (dc: ArrowDataContainer) =>
  (d: CustomScatterplotLayerData) => {
    const pos = dc.valueAt(d.index, geometry.fieldIdx) as ArrowVector<any> | null | undefined;
    if (pos == null) {
      // A missing geometry must not crash the whole layer; report a
      // non-finite position that downstream filtering discards.
      return [NaN, NaN];
    }
    return [pos.get(0), pos.get(1)];
  };

/**
 * Names of the columns the layer requires; returned by the layer's
 * `requiredLayerColumns` getter.
 */
export const customScatterplotRequiredColumns: ['geometry', 'turtle_idx', 'date'] = [
  'geometry',
  'turtle_idx',
  'date'
];

// Single module-level BrushingExtension instance, appended to the extensions
// of every deck.gl layer created in `renderLayer`.
const brushingExtension = new BrushingExtension();

/**
 * Visual configuration entries registered by the layer's constructor through
 * `registerVisConfig`.
 *
 * String entries name a setting by key (presumably resolved against the shared
 * `LAYER_VIS_CONFIGS` by `registerVisConfig` - verify against the base layer).
 * Object entries are defined inline: `filled` starts from
 * `LAYER_VIS_CONFIGS.filled` and overrides some of its properties, while
 * `individualColor` is specific to this layer.
 */
export const customScatterplotVisConfigs: {
  radius: 'radius';
  fixedRadius: 'fixedRadius';
  opacity: 'opacity';
  outline: 'outline';
  thickness: 'thickness';
  strokeColor: 'strokeColor';
  colorRange: 'colorRange';
  strokeColorRange: 'strokeColorRange';
  radiusRange: 'radiusRange';
  filled: VisConfigBoolean;
  individualColor: VisConfigBoolean;
} = {
  radius: 'radius',
  fixedRadius: 'fixedRadius',
  opacity: 'opacity',
  outline: 'outline',
  thickness: 'thickness',
  strokeColor: 'strokeColor',
  colorRange: 'colorRange',
  strokeColorRange: 'strokeColorRange',
  radiusRange: 'radiusRange',
  filled: {
    ...LAYER_VIS_CONFIGS.filled,
    type: 'boolean',
    label: 'layer.fillColor',
    defaultValue: true,
    property: 'filled'
  },
  // Boolean toggle; `updateLayerConfig` on the layer reacts to it by setting
  // or clearing the color field.
  individualColor: {
    type: 'boolean',
    label: 'color mode',
    defaultValue: false,
    property: 'individualColor',
    group: PROPERTY_GROUPS.display
  }
};

/**
 * Kepler.gl layer that renders rows of an Arrow-backed dataset as two stacked
 * deck.gl `ScatterplotLayer`s: a "trajectory" layer styled by the layer's
 * visual channels, and a "last point" layer drawn with a fixed radius and a
 * red outline whose filter range is derived via `buildOneDayFilter`.
 */
export default class CustomScatterplotLayer extends Layer {
  declare config: CustomScatterplotLayerConfig;
  declare visConfigSettings: CustomScatterplotLayerVisConfigSettings;

  /**
   * Dataset field named `turtle_idx`, cached by `calculateDataAttribute`.
   * Remains `undefined` until that method has run, or when the dataset has no
   * such field.
   */
  defaultColorField: Field | undefined;

  constructor(props) {
    super(props);

    this.registerVisConfig(customScatterplotVisConfigs);
    // Bind lazily to `this.config.columns` so the accessor always reflects the
    // columns configured at call time.
    this.getPositionAccessor = (dataContainer) =>
      customScatterplotPosAccessor(this.config.columns)(dataContainer);
  }

  /** Returns the per-datum position accessor for the given data container. */
  getPositionAccessor: (
    dataContainer: ArrowDataContainer
  ) => (d: CustomScatterplotLayerData) => any;

  /** Layer type identifier (static form). */
  static get type(): 'customScatterplot' {
    return 'customScatterplot';
  }

  /** Layer name. */
  override get name(): 'customScatterplot' {
    return 'customScatterplot';
  }

  /** Layer type identifier (instance form); mirrors the static getter. */
  override get type(): 'customScatterplot' {
    return CustomScatterplotLayer.type;
  }

  /** This layer is never aggregated. */
  get isAggregated(): false {
    return false;
  }

  /** Reuses the point layer's icon. */
  get layerIcon() {
    return PointLayerIcon;
  }

  /** Columns that must be configured for this layer. */
  get requiredLayerColumns() {
    return customScatterplotRequiredColumns;
  }

  /** Delegates to the base layer's default point column pairs. */
  get columnPairs() {
    return this.defaultPointColumnPairs;
  }

  /** The base layer's list extended with `'radius'`. */
  get noneLayerDataAffectingProps() {
    return [...super.noneLayerDataAffectingProps, 'radius'];
  }

  /**
   * Visual channels of the layer:
   * - `color`: the base color channel bound to `getFillColor`, active when
   *   `visConfig.filled` is set;
   * - `strokeColor`: bound to `getLineColor`, active when `visConfig.outline`
   *   is set;
   * - `size`: the base size channel bound to `getRadius`.
   */
  get visualChannels() {
    return {
      color: {
        ...super.visualChannels.color,
        accessor: 'getFillColor',
        condition: (config) => config.visConfig.filled,
        defaultValue: (config) => config.color
      },
      strokeColor: {
        property: 'strokeColor',
        key: 'strokeColor',
        field: 'strokeColorField',
        scale: 'strokeColorScale',
        domain: 'strokeColorDomain',
        range: 'strokeColorRange',
        channelScaleType: CHANNEL_SCALES.color,
        accessor: 'getLineColor',
        condition: (config) => config.visConfig.outline,
        defaultValue: (config) => config.visConfig.strokeColor || config.color
      },
      size: {
        ...super.visualChannels.size,
        property: 'radius',
        range: 'radiusRange',
        fixed: 'fixedRadius',
        channelScaleType: 'radius',
        accessor: 'getRadius',
        defaultValue: 1
      }
    };
  }

  /**
   * Applies a config update, keeping `colorField` in sync with the
   * `visConfig.individualColor` toggle.
   *
   * - Updates that do not carry `visConfig.individualColor` are forwarded
   *   unchanged.
   * - Otherwise the incoming `visConfig` is merged over the current one, so a
   *   partial `visConfig` never drops existing settings, and then:
   *   - when no default color field is known, `individualColor` is forced to
   *     `false` and `colorField` is left as supplied by the caller;
   *   - when it is known, `colorField` is set to the default color field if
   *     `individualColor` is true and cleared otherwise.
   *
   * NOTE(review): every update that carries `visConfig.individualColor`
   * rewrites `colorField`, even when the flag itself did not change - confirm
   * this is intended.
   *
   * @param newConfig - Partial layer config to apply.
   * @returns This layer (whatever the base implementation returns).
   */
  // @ts-ignore
  override updateLayerConfig(newConfig: Partial<CustomScatterplotLayerConfig>): this {
    const incomingVisConfig = newConfig.visConfig;
    const individualColor = incomingVisConfig?.individualColor;

    if (incomingVisConfig === undefined || individualColor === undefined) {
      return super.updateLayerConfig(newConfig);
    }

    // Merge once for all branches below; previously the "no default color
    // field" branch skipped this merge, unlike its siblings.
    const visConfig = {...this.config.visConfig, ...incomingVisConfig};

    if (this.defaultColorField === undefined) {
      // Without a default color field, individual coloring cannot be enabled.
      return super.updateLayerConfig({
        ...newConfig,
        visConfig: {...visConfig, individualColor: false}
      });
    }

    return super.updateLayerConfig({
      ...newConfig,
      visConfig,
      colorField: individualColor ? this.defaultColorField : undefined
    });
  }

  /**
   * Picks an initial color field for a freshly created layer.
   *
   * Does nothing when the dataset has no rows or when
   * `findDefaultColorField` yields nothing; otherwise stores the field as
   * `colorField` and refreshes the `color` visual channel.
   *
   * @param dataset - Dataset the layer is attached to.
   * @returns This layer.
   */
  setInitialLayerConfig(dataset) {
    if (!dataset.dataContainer.numRows()) {
      return this;
    }
    const defaultColorField = findDefaultColorField(dataset);

    if (defaultColorField) {
      this.updateLayerConfig({
        colorField: defaultColorField
      });
      this.updateLayerVisualChannel(dataset, 'color');
    }

    return this;
  }

  /**
   * Base default config extended with the defaults of the stroke color
   * visual channel.
   *
   * @param props - Partial base config forwarded to the base implementation.
   */
  getDefaultLayerConfig(props: LayerBaseConfigPartial) {
    return {
      ...super.getDefaultLayerConfig(props),

      // add stroke color visual channel
      strokeColorField: null,
      strokeColorDomain: [0, 1],
      strokeColorScale: 'quantile'
    };
  }

  /**
   * Builds the renderable data array from the dataset's filtered rows.
   *
   * Side effect: caches the dataset field named `turtle_idx` (if any) in
   * `defaultColorField`.
   *
   * @param dataset - Source table; `filteredIndex` and `fields` are read.
   * @param getPosition - Accessor returning the position of a `{index}` datum.
   * @returns One entry per filtered row whose position components are all
   *   finite numbers.
   */
  calculateDataAttribute(dataset: KeplerTable, getPosition) {
    const {filteredIndex} = dataset;

    const data: CustomScatterplotLayerData[] = [];
    for (let i = 0; i < filteredIndex.length; i++) {
      const index = filteredIndex[i];
      const pos = getPosition({index});

      // if doesn't have point lat or lng, do not add the point
      // deck.gl can't handle position = null
      if (pos.every(Number.isFinite)) {
        data.push({
          position: pos,
          index
        });
      }
    }
    this.defaultColorField = dataset.fields.find((field) => field.name === 'turtle_idx');
    return data;
  }

  /**
   * Assembles the data and accessors later spread into the deck.gl layers.
   *
   * @param datasets - Datasets keyed by id; the one at `config.dataId` is used.
   * @param oldLayerData - Previous result, forwarded to `updateData` and
   *   `formatTextLabelData`.
   * @returns An empty object when `config.dataId` is `null`; otherwise the
   *   data array, position accessor, GPU filter value accessor, text labels
   *   and the attribute accessors of the visual channels.
   */
  formatLayerData(datasets, oldLayerData) {
    if (this.config.dataId === null) {
      return {};
    }
    const {textLabel} = this.config;
    const {gpuFilter, dataContainer} = datasets[this.config.dataId];
    const {data, triggerChanged} = this.updateData(datasets, oldLayerData);
    const getPosition = this.getPositionAccessor(dataContainer);

    // get all distinct characters in the text labels
    const textLabels = formatTextLabelData({
      textLabel,
      triggerChanged,
      oldLayerData,
      data,
      dataContainer
    });

    const accessors = this.getAttributeAccessors({dataContainer});

    return {
      data,
      getPosition,
      getFilterValue: gpuFilter.filterValueAccessor(dataContainer)(),
      textLabels,
      ...accessors
    };
  }

  /**
   * Recomputes the layer bounds from the point positions and stores them in
   * the layer meta.
   *
   * @param dataContainer - Container the positions are read from.
   */
  updateLayerMeta(dataContainer) {
    const getPosition = this.getPositionAccessor(dataContainer);
    const bounds = this.getPointsBounds(dataContainer, getPosition);
    this.updateMeta({bounds});
  }

  /**
   * Creates the deck.gl layers for the current state.
   *
   * @param opts - Render options; `data`, `gpuFilter`, `mapState` and
   *   `interactionConfig` are read here and the whole object is forwarded to
   *   `getDefaultDeckLayerProps`.
   * @returns `[trajectoryLayer, lastPointLayer]`, both `ScatterplotLayer`s
   *   over the same data.
   */
  renderLayer(opts) {
    const {data, gpuFilter, mapState, interactionConfig} = opts;

    // if no field size is defined we need to pass fixed radius = false
    const fixedRadius = this.config.visConfig.fixedRadius && Boolean(this.config.sizeField);
    const radiusScale = this.getRadiusScaleByZoom(mapState, fixedRadius);

    const layerProps = {
      stroked: this.config.visConfig.outline,
      filled: this.config.visConfig.filled,
      lineWidthScale: this.config.visConfig.thickness,
      radiusScale,
      ...(this.config.visConfig.fixedRadius ? {} : {radiusMaxPixels: 500})
    };

    const updateTriggers = {
      getPosition: this.config.columns,
      getFilterValue: gpuFilter.filterValueUpdateTriggers,
      ...this.getVisualChannelUpdateTriggers()
    };

    const defaultLayerProps = this.getDefaultDeckLayerProps(opts);
    const brushingProps = this.getBrushingExtensionProps(interactionConfig);
    const extensions = [...defaultLayerProps.extensions, brushingExtension];
    const zoomFactor = this.getZoomFactor(mapState);

    // Main layer: every point, styled by the configured visual channels.
    const trajectoryLayer = new ScatterplotLayer({
      ...defaultLayerProps,
      ...brushingProps,
      ...layerProps,
      ...data,
      id: `${defaultLayerProps.id}-traj`,
      parameters: {
        // no altitude
        depthTest: false
      },
      lineWidthUnits: 'pixels',
      updateTriggers,
      extensions,
      transitions: {
        getPosition: {
          type: 'interpolation'
        }
      }
    });

    // Replace the filter range only when one is present (presumably narrowing
    // it to a single day, per the helper's name - verify against
    // `buildOneDayFilter`).
    const newFilters = defaultLayerProps.filterRange
      ? buildOneDayFilter(defaultLayerProps.filterRange)
      : undefined;
    // Highlight layer: same data, fixed radius and red outline, with the
    // overridden filter range.
    const lastPointLayer = new ScatterplotLayer({
      ...defaultLayerProps,
      ...brushingProps,
      ...layerProps,
      ...data,
      id: `${defaultLayerProps.id}-point`,
      parameters: {
        depthTest: false
      },
      updateTriggers,
      extensions,
      transitions: {
        getPosition: {
          type: 'interpolation'
        }
      },
      positionFormat: 'XY',
      filterRange: newFilters,

      // Fill color
      getColor: undefined,
      getRadius: LAST_POINT_RADIUS,

      // Stroke
      stroked: true,
      getLineColor: [255, 55, 55],
      getLineWidth: LAST_POINT_STROKE,
      lineWidthScale: zoomFactor
    });

    return [trajectoryLayer, lastPointLayer];
  }
}
