// SPDX-License-Identifier: MIT
// Copyright contributors to the kepler.gl project

import {Vector, Float64} from 'apache-arrow';
import {DateTime} from 'luxon';

import {ObjectInfo} from '@deck.gl/core/lib/layer';
import {ScatterplotLayer as DeckGLScatterplotLayer} from '@deck.gl/layers/typed';

import {ColorRange} from '@kepler.gl/constants';
import {Merge, VisConfigColorRange, VisConfigNumber, VisConfigRange} from '@kepler.gl/types';
import {default as KeplerTable} from '@kepler.gl/table';
import {
  ArrowDataContainer,
  DataContainerInterface,
  buildOneDayFilter,
  bufferFromAccessor
} from '@kepler.gl/utils';
import {PeriodicTripsLayer as DeckGLPeriodicTripsLayer} from '@kepler.gl/deckgl-layers';

import Layer, {LayerBaseConfig, LayerColumn} from '../base-layer';
import TripLayerIcon from '../trip-layer/trip-layer-icon';
import TripInfoModalFactory from '../trip-layer/trip-info-modal';

/** Unix epoch (0 ms) as a luxon `DateTime`; day offsets in this file are measured from it. */
const EPOCH: DateTime<true> = DateTime.fromMillis(0) as DateTime<true>;
/** Base radius of the scatterplot points, before the zoom factors are applied in `renderLayer`. */
const LAST_POINT_RADIUS: number = 5;
/** Outline width of the scatterplot points, before `lineWidthScale` is applied in `renderLayer`. */
const LAST_POINT_STROKE: number = 1.5;

/**
 * Columns the layer reads from the dataset.
 */
export type ArrowTripLayerColumnsConfig = {
  // Position column; the first two components of each value are used as the point coordinates.
  geometry: LayerColumn;
  // Path identifier column; a new path starts at every row whose value differs from the previous row.
  turtle_idx: LayerColumn;
  // Date column; used for the animation timestamps and the animation domain.
  date: LayerColumn;
};

/**
 * Layer-specific metadata stored in `layer.meta`.
 */
export type ArrowTripLayerMeta = {
  // Initialised empty in the constructor; only referenced by the commented-out CPU-filtering code.
  visibilityBuffer: Uint8ClampedArray;
};

/**
 * Binary data object built by `calculateDataAttribute` and passed to the deck.gl layers.
 */
type DeckGLTripLayerData = {
  // Number of paths (i.e. number of entries in `startIndices`).
  length: number;
  // Row index at which each path starts.
  startIndices: Uint16Array;
  attributes: {
    // Flat [x0, y0, x1, y1, ...] coordinates, one pair per row.
    getPath: {value: Float64Array; size: 2};
    // One timestamp per row, expressed as whole days since EPOCH.
    getTimestamps: {value: Float32Array; size: 1};
    // GPU filter values; `size` is the number of entries in `gpuFilter.filterRange`.
    getFilterValue: {value: Float32Array; size: number};
  };
};
/**
 * Shape of the formatted layer data as declared by `formatLayerData`.
 */
export type ArrowTripLayerData = {
  animationConfig: any;
  data: DeckGLTripLayerData;
};

/**
 * Setting descriptors (type, range, default, ...) for each visual configuration entry of the layer.
 */
export type ArrowTripLayerVisConfigSettings = {
  opacity: VisConfigNumber;
  thickness: VisConfigNumber;
  colorRange: VisConfigColorRange;
  trailLength: VisConfigNumber;
  sizeRange: VisConfigRange;
};

/**
 * Current values of the layer's visual configuration (one value per setting above).
 */
export type ArrowTripLayerVisConfig = {
  opacity: number;
  // Multiplied by the zoom factors to produce the trail `widthScale` in `renderLayer`.
  thickness: number;
  colorRange: ColorRange;
  // Trail length in days; also used as the `period` of the trips layer.
  trailLength: number;
  sizeRange: [number, number];
};

/** Default value of the `thickness` ("Stroke Width") setting. */
export const defaultThickness = 0.5;
/** Default value of the `trailLength` setting; also the fallback used in `renderLayer` (days). */
export const defaultTrailLengthDay = 3;
/**
 * Visual configuration registered by the layer constructor through `registerVisConfig`.
 *
 * String entries are passed by name (presumably resolved against kepler.gl's shared vis-config
 * definitions - confirm against `registerVisConfig`); object entries are defined inline.
 */
export const arrowTripVisConfigs: {
  opacity: 'opacity';
  thickness: VisConfigNumber;
  colorRange: 'colorRange';
  trailLength: VisConfigNumber;
  sizeRange: 'strokeWidthRange';
} = {
  opacity: 'opacity',
  // Stroke width of the trails, adjustable between 0 and 100 in steps of 0.1.
  thickness: {
    type: 'number',
    defaultValue: defaultThickness,
    label: 'Stroke Width',
    isRanged: false,
    range: [0, 100],
    step: 0.1,
    group: 'stroke',
    property: 'thickness'
  },
  colorRange: 'colorRange',
  // Length of the trail, adjustable between 0 and 365 in steps of 1.
  trailLength: {
    type: 'number',
    defaultValue: defaultTrailLengthDay,
    label: 'Trail Length',
    isRanged: false,
    range: [0, 365],
    step: 1.0,
    group: 'trail',
    property: 'trailLength'
  },
  sizeRange: 'strokeWidthRange'
};

/**
 * Layer configuration: the base layer config with `columns` and `visConfig` replaced by the
 * ArrowTrip-specific types.
 */
export type ArrowTripLayerConfig = Merge<
  LayerBaseConfig,
  {columns: ArrowTripLayerColumnsConfig; visConfig: ArrowTripLayerVisConfig}
>;

/** Constant multiplier combined with the map zoom factor for widths and radii in `renderLayer`. */
const zoomFactorValue = 4;

/** Constant width returned by the `size` visual channel's attribute accessor. */
export const defaultLineWidth = 5;

/** Names of the columns that must be configured; returned by `requiredLayerColumns`. */
export const requiredColumns: ['geometry', 'turtle_idx', 'date'] = [
  'geometry',
  'turtle_idx',
  'date'
];

/**
 * Collects the field indices of the three layer columns.
 *
 * @param columns - column configuration of the layer
 * @returns `[geometry.fieldIdx, turtle_idx.fieldIdx, date.fieldIdx]`, in that order
 */
export const featureResolver = (columns: ArrowTripLayerColumnsConfig) => {
  const {geometry, turtle_idx, date} = columns;
  return [geometry.fieldIdx, turtle_idx.fieldIdx, date.fieldIdx];
};

/**
 * Index accessor handed to `gpuFilter.filterValueAccessor`: resolves a row index from the
 * deck.gl object info.
 *
 * @param _dc - data container (unused)
 * @param _d - datum (unused)
 * @param objectInfo - object info whose `index` is returned
 * @returns `objectInfo.index`
 * @throws Error when `objectInfo` is undefined
 */
const getIndexAccessor = (
  _dc: DataContainerInterface,
  _d: any,
  objectInfo?: ObjectInfo<any, any>
) => {
  if (objectInfo !== undefined) {
    return objectInfo.index;
  }

  throw new Error('objectInfo is undefined');
};
/**
 * Data accessor handed to `gpuFilter.filterValueAccessor`: reads one cell of the data container.
 *
 * @param dc - data container to read from
 * @param _d - datum (unused)
 * @param colIdx - index of the column to read
 * @param objectInfo - object info whose `index` selects the row
 * @returns `dc.valueAt(objectInfo.index, colIdx)`
 * @throws Error when `objectInfo` is undefined
 */
const getDataAccessor = (
  dc: DataContainerInterface,
  _d: any,
  colIdx: number,
  objectInfo?: ObjectInfo<any, any>
) => {
  if (objectInfo !== undefined) {
    const rowIdx = objectInfo.index;
    return dc.valueAt(rowIdx, colIdx);
  }

  throw new Error('objectInfo is undefined');
};

/**
 * Kepler.gl layer that renders trips stored in an `ArrowDataContainer`.
 *
 * The layer reads three columns (`geometry`, `turtle_idx`, `date`), packs them into binary
 * attribute buffers and renders two deck.gl sub-layers: a `PeriodicTripsLayer` for the trails and
 * a `ScatterplotLayer` for the points.
 *
 * > NOTE: CPU filtering is currently disabled; the code that refers to the "visibilityBuffer" is
 * > commented out. To re-enable it, uncomment those parts: the function `updateVisibilityBuffer`
 * > should build a visibilityBuffer and update the ref in `layer.meta`.
 * > See commit `93561a9776e20b3cb2a0cafadc19478b3d46ea83` to see the previous implementation.
 */
export default class ArrowTripLayer extends Layer {
  declare visConfigSettings: ArrowTripLayerVisConfigSettings;
  declare config: ArrowTripLayerConfig;
  declare meta: ArrowTripLayerMeta;

  /** Only referenced by the commented-out CPU-filtering code (`getFilter` update trigger). */
  filteredIndexTrigger: number[] | null = null;

  /** Component rendered as the template of the layer info modal. */
  _layerInfoModal: () => JSX.Element;

  constructor(props: any) {
    super(props);

    this.meta = {
      visibilityBuffer: new Uint8ClampedArray()
    };
    this.registerVisConfig(arrowTripVisConfigs);
    this._layerInfoModal = TripInfoModalFactory();
  }

  /** Layer type identifier. */
  static get type(): 'arrowTripLayer' {
    return 'arrowTripLayer';
  }

  override get type() {
    return ArrowTripLayer.type;
  }

  override get name(): 'ArrowTrip' {
    return 'ArrowTrip';
  }

  override get layerIcon() {
    return TripLayerIcon;
  }

  override get columnPairs() {
    return this.defaultPointColumnPairs;
  }

  override get requiredLayerColumns() {
    return requiredColumns;
  }

  /**
   * Visual channels of the layer: the base `color` and `size` channels are re-targeted to the
   * `getColor` / `getWidth` accessors; every other base channel is kept unchanged.
   */
  override get visualChannels() {
    const {color, size, ...baseVisualChannels} = super.visualChannels;

    return {
      ...baseVisualChannels,
      color: {
        ...color,
        accessor: 'getColor',
        getAttributeValue: (config: any) => config.color,
        // used this to get updateTriggers
        defaultValue: (config: any) => config.color
      },
      size: {
        ...size,
        property: 'stroke',
        accessor: 'getWidth',
        condition: (config: any) => config.visConfig.stroked,
        nullValue: 0,
        // The width is constant: every object gets `defaultLineWidth`.
        getAttributeValue: () => (_: any) => defaultLineWidth
      }
    };
  }

  /** Animation domain currently stored in the layer config (set by `updateAnimationDomain`). */
  get animationDomain() {
    return this.config.animation.domain;
  }

  override get layerInfoModal() {
    return {
      id: 'iconInfo',
      template: this._layerInfoModal,
      modalProps: {
        title: 'modal.tripInfo.title'
      }
    };
  }

  /**
   * Builds an accessor returning the `[x, y]` position of the row at `index`.
   *
   * @param dataContainer - container holding the geometry column
   * @returns accessor `(data, {index}) => [x, y]`; when called it throws if `dataContainer` is
   *   not an `ArrowDataContainer` or if either coordinate is `null`
   */
  override getPositionAccessor(
    dataContainer: DataContainerInterface
  ): (data: any, {index}: {index: number}) => number[] {
    const accessor = (_: any, {index}: {index: number}) => {
      if (!(dataContainer instanceof ArrowDataContainer)) {
        throw new Error('Only accept arrow data container');
      }

      const colGeomIdx = this.config.columns.geometry.fieldIdx;
      const posVector = dataContainer.getColumn(colGeomIdx).get(index) as Vector<Float64>;
      const position = [posVector.get(0), posVector.get(1)];
      if (position[0] === null || position[1] === null) {
        throw new Error('Malformed position data, lon or lat is null');
      }

      return position as [number, number];
    };

    return accessor;
  }

  /**
   * Looks for fields named `geometry`, `turtle_idx` and `date` in the table and, when found,
   * proposes one default layer per matching column set.
   *
   * @param table - dataset to inspect (`label`, `fields`, `dataContainer`, `id`)
   * @param foundLayers - layers already found, passed through unchanged when a match is found
   * @returns `{props, foundLayers}` when matching columns exist, otherwise `{props: []}`
   */
  static findDefaultLayerProps(
    {label, fields = [], dataContainer, id}: KeplerTable,
    foundLayers: any[]
  ) {
    const defaultColumns = {
      geometry: ['geometry'],
      turtle_idx: ['turtle_idx'],
      date: ['date']
    };

    const layerColumns = this.findDefaultColumnField(defaultColumns, fields);

    if (layerColumns && layerColumns.length) {
      return {
        props: layerColumns.map((columns) => ({
          // Strip the file extension from the dataset label; fall back to the layer type.
          label: (typeof label === 'string' && label.replace(/\.[^/.]+$/, '')) || this.type,
          columns,
          isVisible: true
        })),
        foundLayers
      };
    }

    return {props: []};
  }

  /** Default layer config with animation enabled and no animation domain yet. */
  override getDefaultLayerConfig(props: any) {
    return {
      ...super.getDefaultLayerConfig(props),
      animation: {
        enabled: true,
        domain: null
      }
    };
  }

  /**
   * Returns the column of `dataContainer` at index `object.turtle_idx`.
   *
   * NOTE(review): the hovered object's `turtle_idx` value is used as a column index here -
   * confirm this is the intended hover payload.
   */
  override getHoverData(object: any, dataContainer: any) {
    return dataContainer.getColumn(object.turtle_idx);
  }

  /**
   * Packs the geometry column into a flat `[x0, y0, x1, y1, ...]` buffer.
   *
   * @param dataContainer - container holding the geometry column
   * @returns buffer with two entries per row
   */
  private buildPositionBuffer(dataContainer: ArrowDataContainer): Float64Array {
    const colGeomIdx = this.config.columns.geometry.fieldIdx;

    // Fetch the column once; it does not change while iterating.
    const positionsCol = dataContainer.getColumn(colGeomIdx);
    const positionsBuf = new Float64Array(positionsCol.length * 2);
    for (let i = 0; i < positionsCol.length; ++i) {
      const position: Vector<Float64> = positionsCol.get(i);
      const bufferIdx = i * 2;
      positionsBuf[bufferIdx] = position.get(0)!;
      positionsBuf[bufferIdx + 1] = position.get(1)!;
    }

    return positionsBuf;
  }

  /**
   * Converts the date column into whole-day offsets from `EPOCH`.
   *
   * @param dataContainer - container holding the date column
   * @returns buffer with one day offset per row
   */
  private buildTimestampsBuffer(dataContainer: ArrowDataContainer): Float32Array {
    const colDateIdx = this.config.columns.date.fieldIdx;
    const datesCol = dataContainer.getColumn(colDateIdx);

    const timestampsBuf = new Float32Array(datesCol.length);
    for (let i = 0; i < datesCol.length; ++i) {
      const date = DateTime.fromJSDate(datesCol.get(i));
      const dayDiff = Math.floor(date.diff(EPOCH, 'day').as('day'));
      timestampsBuf[i] = dayDiff;
    }

    return timestampsBuf;
  }

  /**
   * Computes the row index at which each path starts: a new path begins at every row whose
   * `turtle_idx` differs from the previous row's value.
   *
   * NOTE(review): the result is a `Uint16Array` (as required by `DeckGLTripLayerData`), so row
   * indices above 65535 wrap around. Widening the type to `Uint32Array` would be needed for
   * larger datasets.
   *
   * @param dataContainer - container holding the `turtle_idx` column
   * @returns start index of every path, in row order
   */
  private buildStartIndicesBuffer(dataContainer: ArrowDataContainer): Uint16Array {
    const colTurtleIdx = this.config.columns.turtle_idx.fieldIdx;

    const turtleIdxes = dataContainer.getColumn(colTurtleIdx).toArray() as Uint32Array;
    const startIndices: number[] = [];
    let lastTurtleIdx = -1;
    turtleIdxes.forEach((turtleIdx, arrayIdx) => {
      if (turtleIdx !== lastTurtleIdx) {
        startIndices.push(arrayIdx);
        lastTurtleIdx = turtleIdx;
      }
    });

    return Uint16Array.from(startIndices);
  }

  /**
   * Builds the binary data object (paths, timestamps, start indices and GPU filter values)
   * consumed by the deck.gl layers.
   *
   * @param keplerTable - dataset providing `dataContainer` and `gpuFilter`
   * @returns the binary layer data, or `{}` (after logging an error) when the data container
   *   is not an `ArrowDataContainer`
   */
  override calculateDataAttribute(keplerTable: KeplerTable): DeckGLTripLayerData | {} {
    const {
      dataContainer,
      // filteredIndex,
      gpuFilter
    }: {dataContainer: DataContainerInterface; filteredIndex: any; gpuFilter: any} = keplerTable;

    if (!(dataContainer instanceof ArrowDataContainer)) {
      console.error('Type other than ArrowDataContainer are not supported');
      return {};
    }

    // this.updateVisibilityBuffer(dataContainer, filteredIndex);

    const isDataEmpty = dataContainer.numRows() === 0;
    let positionsBuf: Float64Array;
    let timestampsBuf: Float32Array;
    let startIndicesBuf: Uint16Array;
    let nbPaths: number;
    if (!isDataEmpty) {
      // Build positions buffer
      positionsBuf = this.buildPositionBuffer(dataContainer);
      // Build timestamps buffer as day offset from EPOCH
      timestampsBuf = this.buildTimestampsBuffer(dataContainer);
      // Build start indices buffer
      startIndicesBuf = this.buildStartIndicesBuffer(dataContainer);

      nbPaths = startIndicesBuf.length;
    } else {
      // Empty dataset: hand out empty buffers without reading any column.
      positionsBuf = new Float64Array();
      timestampsBuf = new Float32Array();
      startIndicesBuf = new Uint16Array();
      nbPaths = 0;
    }

    //const filteredBuf = Float32Array.from(this.meta.visibilityBuffer);

    const layerData = {
      length: nbPaths,
      startIndices: startIndicesBuf,
      attributes: {
        getPath: {
          value: positionsBuf,
          size: 2
        } as DeckGLTripLayerData['attributes']['getPath'],
        getTimestamps: {
          value: timestampsBuf,
          size: 1
        } as DeckGLTripLayerData['attributes']['getTimestamps']
        //getFiltered: {value: filteredBuf, size: 1}
      }
    };

    // Adding filtering data
    // Build filtering data buffer
    const nbFilters: number = gpuFilter.filterRange.length;
    const accessor = gpuFilter.filterValueAccessor(dataContainer)(
      getIndexAccessor,
      getDataAccessor
    );
    const filterValuesBuf = bufferFromAccessor(
      dataContainer,
      layerData,
      accessor,
      Float32Array,
      nbFilters
    );

    const layerDataWFilter = {
      ...layerData,
      attributes: {
        ...layerData.attributes,
        getFilterValue: {value: filterValuesBuf, size: nbFilters}
      }
    } as DeckGLTripLayerData;

    return layerDataWFilter;
  }

  /**
   * Assembles the props passed to `renderLayer`: the result of `updateData`, the GPU filter value
   * accessor and the attribute accessors of the visual channels.
   *
   * @param datasets - datasets keyed by id
   * @param oldLayerData - previously formatted layer data
   * @returns formatted layer data, or `{}` when the layer has no `dataId`
   */
  override formatLayerData(datasets: any, oldLayerData: any): ArrowTripLayerData | {} {
    if (this.config.dataId === null) {
      return {};
    }

    const data = super.updateData(datasets, oldLayerData);

    const {dataContainer, gpuFilter}: {dataContainer: ArrowDataContainer; gpuFilter: any} =
      datasets[this.config.dataId];

    const accessors = super.getAttributeAccessors({dataContainer});

    return {
      ...data,
      getFilterValue: gpuFilter.filterValueAccessor(dataContainer)(
        getIndexAccessor,
        getDataAccessor
      ),
      ...accessors
    };
  }

  /**
   * Stores a new animation domain in the layer config, keeping the other animation settings.
   *
   * @param domain - `[min, max]` timestamps
   */
  updateAnimationDomain(domain: [number, number]): void {
    const anyDomain = domain as any;
    this.updateLayerConfig({
      animation: {
        ...this.config.animation,
        domain: anyDomain
      }
    });
  }

  /**
   * Recomputes the animation domain (min/max of the date column, in milliseconds) and the
   * point bounds, and stores them in the layer config / meta.
   *
   * @param dataContainer - container holding the date and geometry columns
   */
  override updateLayerMeta(dataContainer: ArrowDataContainer) {
    let animationDomain: [number, number];
    const isDataEmpty = dataContainer.numRows() === 0;
    if (!isDataEmpty) {
      const colDateIdx = this.config.columns.date.fieldIdx;
      const dateCol = dataContainer.getColumn(colDateIdx);

      const arrowDates = dateCol.toArray() as Array<Date>;
      // Single pass tracking the smallest and largest timestamp.
      animationDomain = arrowDates.reduce(
        (animationDomain, arrowDate) => {
          const timestamp = arrowDate.getTime();
          if (timestamp < animationDomain[0]) {
            animationDomain[0] = timestamp;
          }
          if (timestamp > animationDomain[1]) {
            animationDomain[1] = timestamp;
          }
          return animationDomain;
        },
        [Infinity, -Infinity]
      );
    } else {
      // Empty dataset: use a finite, degenerate domain.
      animationDomain = [0, 0];
    }

    this.updateAnimationDomain(animationDomain as [number, number]);

    const getPosition = ({index}: {index: number}, dc: DataContainerInterface) =>
      this.getPositionAccessor(dc)(undefined, {index});
    const bounds = this.getPointsBounds(dataContainer, getPosition);
    this.updateMeta({
      bounds
    });
  }

  /** Initialises the layer meta from the dataset and returns the layer for chaining. */
  setInitialLayerConfig({dataContainer}) {
    this.updateLayerMeta(dataContainer);
    return this;
  }

  /**
   * Creates the deck.gl layers for the current state.
   *
   * @param opts - render options: the formatted layer data plus `animationConfig`, `gpuFilter`
   *   and `mapState`
   * @returns `[tripsLayer, pointLayer]`, or `[]` (after logging an error) when `opts` is empty
   *   or the animation domain / current time is not finite
   */
  override renderLayer(opts: any) {
    if (Object.keys(opts).length === 0) {
      console.error('Formated data is empty');
      return [];
    }

    const {animationConfig, data, gpuFilter, mapState} = opts;
    const {visConfig} = this.config;
    const zoomFactor = this.getZoomFactor(mapState);

    const isValidTime: boolean =
      animationConfig &&
      Array.isArray(animationConfig.domain) &&
      animationConfig.domain.every(Number.isFinite) &&
      Number.isFinite(animationConfig.currentTime);

    if (!isValidTime) {
      console.error('Time is not valid');
      return [];
    }

    const domain0 = animationConfig.domain?.[0];

    const updateTriggers = {
      ...this.getVisualChannelUpdateTriggers(),
      getTimestamps: {
        columns: {
          date: this.config.columns.date
        },
        domain0
      },
      getFilterValue: gpuFilter.filterValueUpdateTriggers
      //getFilter: this.filteredIndexTrigger
    };
    const defaultLayerProps = this.getDefaultDeckLayerProps(opts as any);

    // Time in days
    const trailLength = visConfig.trailLength ?? defaultTrailLengthDay;
    const currentTime = DateTime.fromMillis(animationConfig.currentTime)
      .diff(EPOCH, 'day')
      .as('day');

    const mainLayer = new DeckGLPeriodicTripsLayer({
      ...defaultLayerProps,
      ...data,
      id: `${defaultLayerProps.id}-trip`,
      widthScale: visConfig.thickness * zoomFactor * zoomFactorValue,
      capRounded: true,
      jointRounded: true,
      wrapLongitude: false,
      parameters: {
        depthTest: mapState.dragRotate,
        depthMask: false
      },
      trailLength,
      period: trailLength,
      currentTime,
      updateTriggers,
      positionFormat: 'XY',
      _pathType: 'open'
    });
    const layerData: DeckGLTripLayerData = data.data;
    const oneDayFilter = defaultLayerProps.filterRange
      ? buildOneDayFilter(defaultLayerProps.filterRange)
      : undefined;
    const lastPointLayer = new DeckGLScatterplotLayer({
      ...defaultLayerProps,
      ...data,
      // Derive the id from the layer id (not the whole props object) so that every
      // ArrowTripLayer instance gets a distinct point sub-layer id.
      id: `${defaultLayerProps.id}-point`,
      // One point per row, reusing the path coordinates and filter values of the trips layer.
      data: {
        length: layerData.attributes.getPath.value.length / layerData.attributes.getPath.size,
        attributes: {
          getPosition: layerData.attributes.getPath,
          getFilterValue: layerData.attributes.getFilterValue
        }
      },
      updateTriggers,
      positionFormat: 'XY',
      filterRange: oneDayFilter,

      // Fill color
      getColor: undefined,
      getFillColor: data.getColor,
      getRadius: LAST_POINT_RADIUS * zoomFactor * zoomFactorValue,

      // Stroke
      stroked: true,
      getLineColor: [255, 55, 55],
      getLineWidth: LAST_POINT_STROKE,
      lineWidthScale: zoomFactor * zoomFactorValue
    });

    return [mainLayer, lastPointLayer];
  }
}
