import {keyToString} from '../runs';
import {Key} from '../runTypes';
import {avg, median, stddev, stderr} from './math';
import {type AggregateCalculation, type Line, type Point} from './types';

/** Accumulator used while unioning lines: x-value -> every y-value seen at that x. */
type XValMap = Map<number, number[]>;
/** A single x-value paired with all of the y-values collected for it. */
type Bucket = [x: number, ys: number[]];

/**
 * Pools the points of every line into buckets keyed by x-value.
 *
 * Each bucket holds the y-values of all points sharing that x, in the order
 * they were encountered (line order, then point order within a line).
 *
 * @param lines - lines whose `data` points are merged
 * @returns buckets sorted by ascending x
 */
export const unionLines = (lines: Line[]): Bucket[] => {
  const buckets: XValMap = new Map();
  for (const {data} of lines) {
    for (const {x, y} of data) {
      const ys = buckets.get(x);
      if (ys === undefined) {
        buckets.set(x, [y]);
      } else {
        ys.push(y);
      }
    }
  }
  return [...buckets].sort(([xa], [xb]) => xa - xb);
};

/**
 * Builds an area point spanning a bucket's full range: `y0` is the smallest
 * collected value and `y` the largest.
 *
 * Uses pairwise folds rather than spreading the values into `Math.min`/`Math.max`,
 * which yields the same result without depending on argument-count limits.
 */
export const calcMinMax = ([key, vals]: Bucket) => {
  const lowest = vals.reduce((acc, v) => Math.min(acc, v), Infinity);
  const highest = vals.reduce((acc, v) => Math.max(acc, v), -Infinity);
  return {x: key, y0: lowest, y: highest};
};
/**
 * Builds an area point spanning one standard deviation either side of the
 * bucket's midpoint.
 *
 * @param bucket - the x-value and the y-values collected at it
 * @param midPointFn - computes the band's centre from the values
 * @param cb - invoked with the spread actually applied (a NaN standard
 *   deviation is replaced by 0, collapsing the band onto the midpoint)
 * @returns `{x, y0, y}` with `y0`/`y` the lower/upper band edges
 */
export const calcStddev = (
  [key, vals]: Bucket,
  midPointFn: (vals: number[]) => number,
  cb: (stddev: number) => void
) => {
  const center = midPointFn(vals);
  const measured = stddev(vals, center);
  const spread = isNaN(measured) ? 0 : measured;
  cb(spread);
  const lower = center - spread;
  const upper = center + spread;
  return {x: key, y0: lower, y: upper};
};

/**
 * Builds an area point spanning one standard error either side of the
 * bucket's midpoint. A NaN standard error is treated as 0, so the band
 * collapses onto the midpoint.
 *
 * @param bucket - the x-value and the y-values collected at it
 * @param midPointFn - computes the band's centre from the values
 * @returns `{x, y0, y}` with `y0`/`y` the lower/upper band edges
 */
const calcStdErr = (
  [key, vals]: Bucket,
  midPointFn: (vals: number[]) => number
) => {
  const center = midPointFn(vals);
  const measured = stderr(vals);
  const halfWidth = isNaN(measured) ? 0 : measured;
  return {x: key, y0: center - halfWidth, y: center + halfWidth};
};

/**
 * Reduces one bucket (an x-value plus every y-value collected at that x) to
 * the single point plotted for it.
 */
type AggregateBucket = (bucket: [x: number, vals: number[]]) => Point;

/** Point at the bucket's x whose y is the smallest collected value (Infinity if none). */
export const calcMin: AggregateBucket = ([x, values]) => {
  const smallest = values.reduce((lo, v) => Math.min(lo, v), Infinity);
  return {x, y: smallest};
};

/** Point at the bucket's x whose y is the largest collected value (-Infinity if none). */
export const calcMax: AggregateBucket = ([x, values]) => {
  const largest = values.reduce((hi, v) => Math.max(hi, v), -Infinity);
  return {x, y: largest};
};

/** Point at the bucket's x whose y is the `median` helper applied to its values. */
export const calcMedian: AggregateBucket = ([x, values]) => {
  const y = median(values);
  return {x, y};
};

/** Point at the bucket's x whose y is the `avg` helper applied to its values. */
export const calcAvg: AggregateBucket = ([x, values]) => {
  const y = avg(values);
  return {x, y};
};

/**
 * Maps the `line` option of `packageUnionLines` to the reducer that turns
 * each bucket into the central plotted point.
 */
const pointToAgg: Record<string, AggregateBucket> = {
  min: calcMin,
  max: calcMax,
  mean: calcAvg,
  median: calcMedian,
};

/**
 * Builds the fields shared by every aggregated series of a group.
 *
 * The run reference is borrowed from the group's first line. An empty group
 * yields `run: undefined` instead of throwing on `lines[0].run`.
 *
 * @param lines - the lines being aggregated (may be empty)
 * @param name - the group's name, copied onto the aggregated series
 */
const makeRun = (lines: Line[], name: string) => ({
  title: '',
  run: lines[0]?.run,
  name,
  // NOTE(review): always empty here; `packageUnionLines` overwrites it only on
  // the central line (with the averaged extra vars). Confirm aux series should
  // carry no vars.
  vars: {},
});

/**
 * Aggregates a group of lines into the series drawn for that group.
 *
 * All points are first bucketed by x-value via `unionLines`. From the buckets:
 * - `meanLine` is always built. It has one point per bucket, reduced with the
 *   `line` aggregation (`mean`, `median`, `min` or `max`), and carries the
 *   averaged `extraVars`.
 * - `minmaxLine`, `stderrLine` and `stddevLine` are auxiliary area series.
 *   Each is built only when the matching `calculate` flag is set; otherwise it
 *   is `null`.
 * - `sampleLines` holds the original lines re-tagged as dotted aux series when
 *   `calculate.samples` is set; otherwise it is `null`.
 *
 * The stderr/stddev bands are centred on the mean when `line === 'mean'` and
 * on the median for every other `line` value.
 *
 * @param lines - lines in the group; their points are unioned by x
 * @param name - name/displayName applied to the aggregated series
 * @param line - key of the central aggregation; defaults to `'mean'`
 * @param calculate - which auxiliary series to build
 * @param extraVars - var keys whose per-line values (those passing the
 *   `isFinite` check) are averaged onto `meanLine.vars`; defaults to none
 * @returns the series keyed by role, with `null` for disabled ones
 * @throws Error if `line` is not a supported aggregation
 */
export const packageUnionLines = ({
  lines,
  name,
  line = 'mean',
  calculate,
  extraVars = [],
}: {
  lines: Line[];
  name: string;
  calculate: Record<AggregateCalculation, boolean>;
  // Optional because both have defaults in the destructuring above.
  line?: string;
  extraVars?: Key[];
}) => {
  // Resolve the central aggregation first so an unsupported `line` fails with
  // a descriptive error rather than `map` rejecting `undefined` later on. The
  // own-property check keeps inherited members (e.g. "toString") from being
  // mistaken for reducers.
  const aggregatePoint = Object.prototype.hasOwnProperty.call(pointToAgg, line)
    ? pointToAgg[line]
    : undefined;
  if (aggregatePoint === undefined) {
    throw new Error(
      `packageUnionLines: unsupported line aggregation "${line}"; ` +
        `expected one of ${Object.keys(pointToAgg).join(', ')}`
    );
  }

  const buckets = unionLines(lines);

  // Bands are centred on the mean only for `line === 'mean'`; every other
  // aggregation uses the median as the band centre.
  const midPoint = line === 'mean' ? avg : median;

  const aggregateExtraVars: {[key: string]: number} = {};
  extraVars.forEach(varKey => {
    const varName = keyToString(varKey);
    const vals = lines
      .map(l => (l.vars != null ? l.vars[varName] : null))
      .filter(y => y != null && isFinite(y)) as number[];
    aggregateExtraVars[varName] = avg(vals);
  });

  // calcStddev reports each bucket's spread through its callback. Every call
  // overwrites this, so it ends up holding the LAST bucket's value (undefined
  // when there are no buckets). NOTE(review): confirm consumers of
  // `stddevLine.stddev` expect that rather than a group-wide figure.
  let lastStddev: number | undefined;
  const stdDevData = calculate.stddev
    ? buckets.map(b => calcStddev(b, midPoint, s => (lastStddev = s)))
    : undefined;

  const toReturn = {
    minmaxLine: calculate.minmax
      ? {
          ...makeRun(lines, name),
          aggType: 'minmax',
          aux: true,
          data: buckets.map(calcMinMax),
          displayName: name,
          meta: {
            aggregation: 'minmax',
            category: 'grouped',
            excludeOutliers: 'include-outliers',
            mode: 'sampled',
            type: 'area',
          },
          type: 'area',
        }
      : null,
    stderrLine: calculate.stderr
      ? {
          ...makeRun(lines, name),
          aggType: 'stderr',
          aux: true,
          data: buckets.map(b => calcStdErr(b, midPoint)),
          meta: {
            aggregation: 'stderr',
            category: 'grouped',
            excludeOutliers: 'include-outliers',
            mode: 'sampled',
            type: 'area',
          },
          type: 'area',
          displayName: name,
        }
      : null,
    stddevLine: calculate.stddev
      ? {
          ...makeRun(lines, name),
          aggType: 'stddev',
          aux: true,
          data: stdDevData,
          meta: {
            aggregation: 'stddev',
            category: 'grouped',
            mode: 'sampled',
            minMaxOnHover: 'unused',
            type: 'area',
          },
          stddev: lastStddev,
          type: 'area',
          displayName: name,
        }
      : null,
    sampleLines: calculate.samples
      ? lines.map(l => {
          return {
            ...l,
            aux: true,
            mark: 'dotted',
            meta: {
              aggregation: 'samples',
              category: 'grouped',
              mode: 'sampled',
              minMaxOnHover: 'unused',
              type: 'area',
            },
          };
        })
      : null,
    meanLine: {
      ...makeRun(lines, name),
      aggType: line,
      data: buckets.map(aggregatePoint),
      meta: {
        aggregation: 'avg',
        category: 'grouped',
        excludeOutliers: 'include-outliers',
        mode: 'sampled',
        type: 'area',
      },
      // Replaces the empty `vars` from makeRun with the averaged extra vars.
      vars: aggregateExtraVars,
    },
  } as Record<string, Line | null>;
  return toReturn;
};
