import _ from 'lodash';

import {YAxisType} from '../../components/WorkspaceDrawer/Settings/types';
import {RunsData} from '../../containers/RunsDataLoader';
import {RunWithRunsetInfo} from '../../state/runs/types';
import * as ColorUtil from '../colors';
import {
  configKeysInExpression,
  evaluateExpression,
  Expression,
  expressionToString,
  metricsInExpression,
  summaryKeysInExpression,
} from '../expr';
import {
  legendTemplateRemoveCrosshairValues,
  legendTemplateToFancyLegendProps,
  parseLegendTemplate,
} from '../legend';
import * as Run from '../runs';
import {keyToString} from '../runs';
import * as RunTypes from '../runTypes';
import {RunColorConfig} from '../section';
import {
  axisLabelMargin,
  formatYAxis,
  getScaleFnFromScaleObject,
  getYAxisWidth,
} from './axis';
import type {Chart, ChartAggOption} from './chartTypes';
import * as PlotMath from './math';
import {PlotFontSize, plotFontSizeToPx} from './plotFontSize';
import {prettifyMetricName} from './prettifyMetricName';
import {AggregateCalculation, Bar, Line, RunSetInfo, Scalar} from './types';

// Number of histogram bins requested from PlotMath.bin for each aggregated group when a
// violin plot is drawn (passed to aggregatePoints by pointsFromRunset).
const DEFAULT_VIOLIN_PLOT_BINS = 10;

/**
 * Builds the default set of line charts for a run's metrics.
 *
 * One chart is created, in this order, for each "important" metric family that has at least
 * one matching metric name: loss (`/loss/`), accuracy (`accuracy` anywhere or a trailing
 * `acc`), and `n_success`. Each chart is 6 units wide and 2 tall. Charts are placed left to
 * right and wrap to a new row once x reaches 12. A lone chart is widened to 12 units.
 *
 * @param metricNames - metric names to match. Any string array is accepted; the old
 *   `[string]` type was a one-element tuple, which rejected ordinary `string[]` arguments.
 * @returns the generated charts; empty when no metric name matches.
 */
export function defaultRunCharts(metricNames: readonly string[]): Chart[] {
  const importantRegexes = [/loss/, /accuracy|acc$/, /n_success/];
  const charts: Chart[] = [];
  let x = 0;
  let y = 0;

  for (const regex of importantRegexes) {
    const lines = metricNames.filter(name => regex.test(name));
    if (lines.length === 0) {
      continue;
    }
    charts.push({
      config: {lines},
      layout: {x, y, w: 6, h: 2},
    });
    // Two charts per row: wrap to the next row after the second column.
    x += 6;
    if (x === 12) {
      x = 0;
      y += 2;
    }
  }

  if (charts.length === 1) {
    charts[0].layout.w = 12;
  }
  return charts;
}

/**
 * Drops points that cannot be plotted.
 *
 * A point is kept only if both `x` and `y` are finite numbers. On a log x-scale it must also
 * have `x > 0`.
 *
 * NOTE: mutates the input. Each line's `data` is replaced by a filtered copy, and the returned
 * array contains the same line objects that were passed in.
 */
export function filterLines(lines: Line[], xScale?: 'log' | 'linear') {
  const requirePositiveX = xScale === 'log';
  return lines.map(line => {
    line.data = line.data.filter(point => {
      if (!_.isFinite(point.x) || !_.isFinite(point.y)) {
        return false;
      }
      return !requirePositiveX || point.x > 0;
    });
    return line;
  });
}

/**
 * Prepares full-fidelity line data for the given x-scale.
 *
 * On a log x-scale, points are dropped when their `x` is not a finite number greater than
 * zero, since such points cannot be placed on a log axis. `y` values are not inspected.
 * On any other scale each line's data is left as is.
 *
 * NOTE: mutates the input. Each line's `data` is reassigned in place, and the returned array
 * holds the same line objects.
 * NOTE(review): unlike filterLines, non-finite `y` values are kept here. Presumably this is
 * intentional for full-fidelity data — confirm.
 */
export function filterFullFidelityLines(
  lines: Line[],
  xScale?: 'log' | 'linear'
) {
  const onLogScale = xScale === 'log';
  return lines.map(line => {
    const {data} = line;
    line.data = onLogScale
      ? data.filter(point => _.isFinite(point.x) && point.x > 0)
      : data;
    return line;
  });
}

/**
 * Prepares lines for a log-scale y-axis by blanking out non-positive y values.
 *
 * Every point with `y <= 0` has its `y` replaced by NaN; all other points are untouched.
 *
 * NOTE: mutates the caller's data. Point objects are modified in place, and each line's
 * `data` is reassigned to a new array holding those same points. The returned array is new,
 * but its elements are the original line objects.
 */
export function filterNegative(lines: Line[]) {
  // TODO: Check if the NaN works ok
  // TODO: This doesn't handle area graphs
  return lines.map(line => {
    const points = line.data.map(point => {
      if (point.y <= 0) {
        point.y = NaN; // a log axis cannot place non-positive values
      }
      return point;
    });
    line.data = points;
    return line;
  });
}

/**
 * Reports whether `arr` never decreases (equal neighbours are allowed).
 *
 * Returns false as soon as a difference between neighbours is negative or NaN.
 * NOTE: an empty array yields `false`; a single-element array yields `true`.
 */
export function isMonotonicIncreasing(arr: number[]) {
  if (arr.length === 0) {
    return false;
  }
  for (let idx = 1; idx < arr.length; idx++) {
    if (!(arr[idx] - arr[idx - 1] >= 0)) {
      return false;
    }
  }
  return true;
}

/**
 * One bar-chart data point: the value of a single metric (or expression) for a single run,
 * before any grouping or aggregation.
 */
interface RunDataPoint {
  // The run the value was read from.
  run: RunWithRunsetInfo;
  // Metric value or evaluated expression result.
  // NOTE(review): filled via `Run.getValueSafe(...) as number`, so the type does not rule out
  // missing or non-numeric values at runtime.
  value: number;
  // Display name of the metric (Run.keyDisplayName) or the stringified expression.
  metricName: string;
  // Grouping id: points sharing it are aggregated together by pointsFromRunset.
  runOrGroupUniqueId: string;
  // Run display name, or the grouped-run display name when grouping keys are set.
  runOrGroupDisplayName: string;
}

/**
 * An aggregated data point produced by aggregatePoints from a group of RunDataPoints.
 * The inherited RunDataPoint fields come from one representative member, and `value` is the
 * aggregate selected by `groupAgg`.
 */
type GroupDataPoint = RunDataPoint & {
  // Keys the runs were grouped by (may be empty). Its presence is what isGroupDataPoint checks.
  groupKeys: RunTypes.Key[];
  // Standard deviation / standard error / mean of the member values.
  stddev?: number;
  stderr?: number;
  mean?: number;
  // Result of PlotMath.quartiles. aggregatePoints reads index 0 as min, 2 as median and
  // 4 as max (presumably [min, q1, median, q3, max]).
  quartiles?: [number, number, number, number, number];
  // Histogram of member values. aggregatePoints only sets it when a bin count is requested
  // (violin plots).
  bins?: Array<{bin: number; count: number}>;
  // [low, high] band chosen by groupArea: min/max, or value ± stddev / stderr.
  range?: [number, number];
};

/**
 * Type guard distinguishing aggregated points from plain per-run points.
 * A point counts as aggregated when its `groupKeys` is present (neither null nor undefined).
 */
function isGroupDataPoint(
  dataPoint: GroupDataPoint | RunDataPoint
): dataPoint is GroupDataPoint {
  const {groupKeys} = dataPoint as Partial<GroupDataPoint>;
  return groupKeys !== null && groupKeys !== undefined;
}

/**
 * Collapses the data points of one group into a single GroupDataPoint.
 *
 * @param points - member points of the group; an empty list yields `null`.
 * @param aggregateCalculations - currently unused; every statistic is always computed.
 * @param groupAgg - statistic reported as `value`: 'mean', 'median', 'max', 'min' or 'sum'.
 *   Any other option yields 0.
 * @param groupArea - band stored in `range`: 'minmax' (min..max), 'stddev' or 'stderr'
 *   (± around `value`); otherwise `range` is undefined.
 * @param groupKeys - keys the group was formed by; copied onto the result.
 * @param numBins - when given, member values are also histogrammed into `bins`.
 * @returns the aggregate point. Its run/metric/id fields come from the min (or max) member
 *   for 'min' ('max'), and from the first member otherwise.
 */
function aggregatePoints(
  points: RunDataPoint[],
  aggregateCalculations: AggregateCalculation[], // currently we do all calculations
  groupAgg: ChartAggOption,
  groupArea: AggregateCalculation,
  groupKeys: RunTypes.Key[],
  numBins?: number
): GroupDataPoint | null {
  if (points.length === 0) {
    return null;
  }

  const values = points.map(p => p.value);
  const mean = PlotMath.avg(values);
  const stddev = PlotMath.stddev(values);
  const stderr = PlotMath.stderr(values);
  const quartiles = PlotMath.quartiles(values);
  const bins = numBins == null ? undefined : PlotMath.bin(values, numBins);

  // Index of the member whose identity (run, names, ids) represents the group.
  let representative = 0;
  let value: number;
  switch (groupAgg) {
    case 'mean':
      value = mean;
      break;
    case 'median':
      value = quartiles[2];
      break;
    case 'max':
      representative = PlotMath.argMax(values);
      value = quartiles[4];
      break;
    case 'min':
      representative = PlotMath.argMin(values);
      value = quartiles[0];
      break;
    case 'sum':
      value = _.sum(values);
      break;
    default:
      value = 0;
  }

  let range: [number, number] | undefined;
  switch (groupArea) {
    case 'minmax':
      range = [quartiles[0], quartiles[4]];
      break;
    case 'stddev':
      range = [value - stddev, value + stddev];
      break;
    case 'stderr':
      range = [value - stderr, value + stderr];
      break;
    default:
      range = undefined;
  }

  const {run, metricName, runOrGroupUniqueId, runOrGroupDisplayName} =
    points[representative];

  return {
    run,
    metricName,
    runOrGroupUniqueId,
    runOrGroupDisplayName,
    value,
    mean,
    stddev,
    stderr,
    quartiles,
    range,
    bins,
    groupKeys,
  };
}

// Inputs to pointsFromRunset.
type PointsFromRunsetParams = {
  // Runs to convert into data points.
  runs: RunWithRunsetInfo[];
  metrics: RunTypes.Key[]; // typically one metric; with more than one, points are aggregated across metrics
  // When non-empty, these are evaluated per run instead of reading `metrics`.
  expressions?: Expression[];
  // Grouping keys; when non-empty, points sharing a group id are aggregated.
  groupKeys: RunTypes.Key[];
  // Forwarded to aggregatePoints, which currently does not read it.
  aggregateCalculations: AggregateCalculation[];
  // Not read by pointsFromRunset.
  legendTemplate?: string;
  // Aggregate reported as a group's value; pointsFromRunset defaults it to 'mean'.
  groupAgg?: ChartAggOption;
  // Band stored in a group's `range`; pointsFromRunset defaults it to 'none'.
  groupArea?: AggregateCalculation;
  // Not read by pointsFromRunset.
  boxPlot?: boolean;
  // When truthy, aggregated points also get DEFAULT_VIOLIN_PLOT_BINS histogram bins.
  violinPlot?: boolean;
  // Its negation is passed to Run.uniqueId when computing point ids.
  // NOTE(review): presumably true lets runs from different runsets share a group — confirm.
  mergeRunsets?: boolean;
};

/**
 * Converts the runs of one runset into bar-chart data points, aggregating them when needed.
 *
 * When `expressions` is non-empty, one point is produced per (run, expression); otherwise one
 * per (run, metric). If runs are grouped (`groupKeys` non-empty) or several metrics are
 * plotted together, points sharing a `runOrGroupUniqueId` are collapsed into a single
 * GroupDataPoint by aggregatePoints (defaults: groupAgg 'mean', groupArea 'none').
 *
 * Expression points and metric points derive `runOrGroupUniqueId` identically, both passing
 * `!mergeRunsets` to Run.uniqueId, so `mergeRunsets` affects both the same way.
 */
function pointsFromRunset(props: PointsFromRunsetParams) {
  const {
    runs,
    metrics,
    expressions,
    groupKeys,
    aggregateCalculations,
    groupAgg,
    groupArea,
    violinPlot,
    mergeRunsets,
  } = props;

  // Every summary/config key referenced by any expression; looked up once per run below.
  const expressionKeys =
    expressions != null
      ? expressions.flatMap(expr =>
          _.concat(summaryKeysInExpression(expr), configKeysInExpression(expr))
        )
      : [];

  // A point's id and label depend only on its run, so compute them once per run.
  const runIdentity = (run: RunWithRunsetInfo) => ({
    runOrGroupUniqueId: Run.uniqueId(run, groupKeys ?? [], !mergeRunsets),
    runOrGroupDisplayName:
      groupKeys.length === 0
        ? run.displayName
        : Run.groupedRunDisplayName(run, groupKeys),
  });

  let barData: RunDataPoint[] = runs.flatMap((run): RunDataPoint[] => {
    const identity = runIdentity(run);

    if (expressions != null && expressions.length > 0) {
      // Expressions take precedence over `metrics`: evaluate each against this run's values.
      const metricsToValue: {[key: string]: number} = {};
      expressionKeys.forEach(exprKey => {
        metricsToValue[keyToString(exprKey)] = Run.getValueSafe(
          run,
          exprKey
        ) as number;
      });
      return expressions.map(expr => ({
        run,
        value: evaluateExpression(expr, metricsToValue),
        metricName: expressionToString(expr),
        ...identity,
      }));
    }

    return metrics.map(key => ({
      run,
      value: Run.getValueSafe(run, key) as number,
      metricName: Run.keyDisplayName(key),
      ...identity,
    }));
  });

  if (groupKeys.length > 0 || metrics.length > 1) {
    const pointsById = _.groupBy(barData, point => point.runOrGroupUniqueId);
    barData = _.map(pointsById, group =>
      aggregatePoints(
        group,
        aggregateCalculations,
        groupAgg ?? 'mean',
        groupArea ?? 'none',
        groupKeys,
        violinPlot ? DEFAULT_VIOLIN_PLOT_BINS : undefined
      )
    ).filter((point): point is GroupDataPoint => point != null);
  }

  return barData;
}

/**
 * Builds the data points for one runset `q`, or for all `runs` when `q` is undefined.
 *
 * When `q` has an id, only runs whose `runsetInfo.id` matches it are used. Grouping comes
 * from `q.grouping`, unless `panelAggregate` is set; then runs are grouped by the config key
 * `groupBy` instead. With `aggregateMetrics`, all metric keys go into one pointsFromRunset
 * call, so they are aggregated together when there are several. Otherwise each metric key
 * gets its own call and the results are concatenated.
 *
 * NOTE(review): pointsFromRunset evaluates every expression on each call regardless of the
 * metrics passed. With `aggregateMetrics` off and expressions present, expression points are
 * therefore repeated once per metric key — confirm this is intended.
 */
function pointResultsForRunset(
  runs: RunsData['filtered'],
  {
    aggregateCalculations,
    aggregateMetrics,
    expressions,
    groupAgg,
    groupArea,
    groupBy,
    legendTemplate,
    metricKeys,
    panelAggregate,
    violinPlot,
  }: PointsFromDataProps,
  q: RunSetInfo | undefined,
  mergeRunsets: boolean
): RunDataPoint[] {
  const runSetID = q?.id;
  const runsForRunset =
    runSetID == null ? runs : runs.filter(r => r.runsetInfo.id === runSetID);

  const groupKeys = panelAggregate
    ? [Run.key('config', groupBy ?? '')]
    : q?.grouping ?? [];

  const pointsFor = (metrics: RunTypes.Key[]) =>
    pointsFromRunset({
      runs: runsForRunset,
      metrics,
      expressions,
      groupKeys,
      aggregateCalculations,
      groupArea,
      groupAgg,
      violinPlot,
      legendTemplate,
      mergeRunsets,
    });

  if (!aggregateMetrics) {
    return metricKeys.flatMap(metricKey => pointsFor([metricKey]));
  }
  return pointsFor(metricKeys);
}

/**
 * Turns data points into renderable bars.
 *
 * Each bar's title comes from `legendTemplate`, and its key is that title with crosshair
 * values removed and trimmed. With `colorEachMetricDifferently`, metrics get palette colours
 * in order of first appearance. Otherwise the run colour is used, or '000000' when a point has
 * no run. Aggregate fields (mean, quartiles, range, stddev, and bins for violin plots) are
 * only filled for GroupDataPoints.
 * `useRunName`, `useMetricName` and `boxPlot` are accepted but not currently read.
 */
const convertRunDataPointsToBars = (props: {
  dataPoints: Array<RunDataPoint | GroupDataPoint>;
  useRunName: boolean;
  useMetricName: boolean;
  legendTemplate: string;
  colorEachMetricDifferently: boolean;
  customRunColors?: RunColorConfig;
  boxPlot?: boolean;
  violinPlot?: boolean;
}): Bar[] => {
  const {
    dataPoints,
    customRunColors,
    violinPlot,
    legendTemplate,
    colorEachMetricDifferently,
  } = props;

  // Palette index per metric, assigned in order of first appearance.
  const metricToColorIdx = new Map<string, number>();
  if (colorEachMetricDifferently) {
    for (const {metricName} of dataPoints) {
      if (!metricToColorIdx.has(metricName)) {
        metricToColorIdx.set(metricName, metricToColorIdx.size);
      }
    }
  }

  return dataPoints.map(point => {
    const group = isGroupDataPoint(point) ? point : undefined;

    const title = parseLegendTemplate(
      legendTemplate,
      true,
      point.run,
      group?.groupKeys ?? [],
      prettifyMetricName(point.metricName)
    );
    const key = legendTemplateRemoveCrosshairValues(title).trim();

    const color = colorEachMetricDifferently
      ? ColorUtil.color(metricToColorIdx.get(point.metricName) ?? 0)
      : point.run
      ? ColorUtil.runColor(point.run, group?.groupKeys ?? [], customRunColors)
      : '000000';

    return {
      bins: violinPlot ? group?.bins : undefined,
      color,
      displayName: point.runOrGroupDisplayName,
      key,
      mean: group?.mean,
      metricName: point.metricName,
      quartiles: group?.quartiles,
      range: group?.range,
      stddev: group?.stddev,
      title,
      uniqueId: point.runOrGroupUniqueId,
      value: point.value,
    };
  });
};

// Bar-chart configuration consumed by getPointsFromData and pointResultsForRunset.
interface PointsFromDataProps {
  // Metric keys to plot.
  metricKeys: RunTypes.Key[];
  // Per-run colour overrides passed to ColorUtil.runColor.
  customRunColors?: RunColorConfig;
  // Config key to group by when `panelAggregate` is set.
  groupBy?: string;
  // When true, grouping uses the config key `groupBy` instead of the runset's grouping.
  panelAggregate?: boolean;
  // Aggregate reported for each group.
  groupAgg?: ChartAggOption;
  // Band stored in each group's `range`.
  groupArea?: AggregateCalculation;
  // When provided, points are built separately for each runset.
  runSets?: RunSetInfo[];
  aggregateCalculations: AggregateCalculation[];
  // Colour bars by metric instead of by run.
  colorEachMetricDifferently: boolean;
  // When true, all metric keys are aggregated together instead of handled one at a time.
  aggregateMetrics: boolean;
  boxPlot?: boolean;
  // When true, aggregated points carry histogram bins.
  violinPlot?: boolean;
  // Template used to build bar titles and keys.
  legendTemplate: string;
  // Expressions evaluated per run; when non-empty they replace metric lookups.
  expressions?: Expression[];
}

/**
 * Converts run data into the bars rendered by the bar chart.
 *
 * @param runs - filtered runs to plot; an empty list yields no bars.
 * @param props - chart configuration. When `runSets` is provided (everywhere but the run
 *   page), points are built separately for each runset. Otherwise (run page) they are built
 *   from all `runs` at once.
 * @returns one Bar per (possibly aggregated) data point.
 */
export const getPointsFromData = (
  runs: RunsData['filtered'],
  props: PointsFromDataProps
) => {
  const {
    boxPlot,
    colorEachMetricDifferently,
    customRunColors,
    legendTemplate,
    metricKeys,
    runSets,
    violinPlot,
  } = props;

  if (runs.length === 0) {
    return [];
  }

  const barResults: RunDataPoint[] =
    runSets != null
      ? runSets.flatMap(runSet =>
          pointResultsForRunset(runs, props, runSet, false)
        )
      : pointResultsForRunset(runs, props, undefined, false);

  const useMetricName = metricKeys.length > 1;
  const useRunName =
    !useMetricName ||
    _.uniq(barResults.map(br => br.runOrGroupDisplayName)).length > 1;

  return convertRunDataPointsToBars({
    boxPlot,
    colorEachMetricDifferently,
    customRunColors,
    dataPoints: barResults,
    legendTemplate: legendTemplate || '',
    useMetricName,
    useRunName,
    violinPlot,
  });
};

/**
 * Converts one (possibly aggregated) data point into a Scalar.
 *
 * The key is built from `legendTemplate` with no grouping keys. When both `entityName` and
 * `projectName` are known, `/<entity>/<project>/runs` is passed to the legend builder as its
 * root URL. The colour is the ungrouped run colour, or '000000' when the point has no run.
 * `range`, `stddev` and `stderr` are only filled for GroupDataPoints.
 * `useRunName` / `useMetricName` are accepted but not currently read.
 */
const convertRunDataPointsToScalar = (props: {
  dataPoint: RunDataPoint | GroupDataPoint;
  useRunName: boolean;
  useMetricName: boolean;
  legendTemplate: string;
  customRunColors?: RunColorConfig;
  entityName?: string;
  projectName?: string;
}): Scalar => {
  const {dataPoint, customRunColors, legendTemplate, entityName, projectName} =
    props;

  const rootUrl =
    entityName == null || projectName == null
      ? undefined
      : `/${entityName}/${projectName}/runs`;

  const key = legendTemplateToFancyLegendProps(
    legendTemplate,
    dataPoint.run,
    [],
    prettifyMetricName(dataPoint.metricName),
    rootUrl
  );

  const color = dataPoint.run
    ? ColorUtil.runColor(dataPoint.run, [], customRunColors)
    : '000000';

  const group = isGroupDataPoint(dataPoint) ? dataPoint : undefined;

  return {
    key,
    color,
    uniqueId: dataPoint.runOrGroupUniqueId,
    value: dataPoint.value,
    range: group?.range,
    stddev: group?.stddev,
    stderr: group?.stderr,
  };
};

// Configuration consumed by getScalarFromData.
interface ScalarFromDataProps {
  // Metric keys to aggregate into the scalar.
  metricKeys: RunTypes.Key[];
  // Per-run colour overrides passed to ColorUtil.runColor.
  customRunColors?: RunColorConfig;
  // Aggregate used to reduce all runs to one value.
  groupAgg?: ChartAggOption;
  // Band stored in the scalar's `range`.
  groupArea?: AggregateCalculation;
  // Only used to look up entity/project names for the resulting point's runset.
  runSets?: RunSetInfo[];
  aggregateCalculations: AggregateCalculation[];
  // Template used to build the scalar's legend key.
  legendTemplate: string;
  // Expressions evaluated per run; when non-empty they replace metric lookups.
  expressions?: Expression[];
  // Fallbacks used when the matching runset does not provide entity/project names.
  entityName?: string;
  projectName?: string;
}

/**
 * Reduces the given runs to a single Scalar.
 *
 * Runs are panel-aggregated on the config key '' (presumably placing every run in one group).
 * All metric keys are aggregated together and runsets are merged, then the first resulting
 * point is converted with convertRunDataPointsToScalar. Entity/project names are taken from
 * that point's runset when available, falling back to `props.entityName` / `props.projectName`.
 *
 * @returns null when there are no runs. When no data points were produced it returns
 *   `{value: 0}`, cast to Scalar (NOTE(review): the other Scalar fields are absent).
 *   Otherwise it returns the scalar for the first point.
 */
export const getScalarFromData = (
  runs: RunsData['filtered'],
  props: ScalarFromDataProps
): Scalar | null => {
  if (runs.length === 0) {
    return null;
  }

  const {metricKeys, customRunColors, legendTemplate} = props;

  const aggregatedProps: PointsFromDataProps = {
    ...props,
    groupBy: '',
    panelAggregate: true,
    aggregateMetrics: true,
    colorEachMetricDifferently: false,
  };
  const points = pointResultsForRunset(runs, aggregatedProps, undefined, true);

  const useMetricName = metricKeys.length > 1;
  const useRunName =
    !useMetricName ||
    _.uniq(points.map(p => p.runOrGroupDisplayName)).length > 1;

  if (points.length === 0) {
    return {value: 0} as Scalar;
  }

  const first = points[0];
  const owningRunset = props.runSets?.find(
    rs => rs.id === first.run.runsetInfo.id
  );

  return convertRunDataPointsToScalar({
    dataPoint: first,
    useRunName,
    useMetricName,
    legendTemplate: legendTemplate || '',
    customRunColors,
    entityName: owningRunset?.entityName ?? props.entityName,
    projectName: owningRunset?.projectName ?? props.projectName,
  });
};

// Inputs to getPlotMargin. Every field is optional; omitted ones fall back to defaults there.
// Time-valued axes (e.g. timestamps) get extra margin because their tick labels are long.
type GetPlotMarginParams = {
  // Keys plotted on each axis; only used to detect time-valued axes (Run.isTimeKeyString).
  axisKeys?: {
    xAxis?: string;
    yAxis?: string;
    zAxis?: string;
  };
  // Y-axis domain; when set, the left margin is sized from the ticks generated for it.
  axisDomain?: {
    yAxis?: number[];
  };
  // Scale type used with axisDomain.yAxis; getPlotMargin defaults it to 'linear'.
  axisType?: {
    yAxis?: YAxisType;
  };
  // Explicit y-axis tick labels; when set, the left margin is sized to fit them
  // (unless axisDomain.yAxis is also set, which takes precedence).
  axisValues?: {
    yAxis?: string[];
  };
  // Requested number of y-axis ticks, forwarded to scale.ticks().
  tickTotal?: {
    yAxis?: number;
  };
  // Axis label font size; getPlotMargin defaults it to 'small'.
  fontSize?: PlotFontSize;
  // Custom formatter for ticks generated from axisDomain.yAxis (used to measure label width).
  yAxisTickFormatter?: (value: number) => string;
};

// Margin on each side of the plot (presumably in px, matching plotFontSizeToPx).
type PlotMargin = Record<'top' | 'bottom' | 'left' | 'right', number>;

// In react-vis, the axes are rendered within the plot margins (i.e. not included in the plot width).
// This function calculates the margins dynamically, based on the size of the axis labels.
/**
 * Computes plot margins that leave room for the axis tick labels.
 *
 * Starts from fixed defaults: top 8, right 14, bottom = axis font size + axisLabelMargin,
 * left 0. It then applies overrides in this order, with later steps winning:
 *  1. `axisValues.yAxis`: left margin sized to those explicit labels.
 *  2. `axisDomain.yAxis`: left margin sized to the labels of the ticks generated for that
 *     domain, formatted with `yAxisTickFormatter`, or by default with Run.formatTimestamp
 *     (time y-axis) or formatYAxis.
 *  3. Time-valued axes: fixed margins for the long timestamp labels.
 *
 * Each tick formatter is called with the tick value only. Passing a formatter straight to
 * `Array.prototype.map` would also hand it the index and array, which a formatter with an
 * optional second parameter would misinterpret.
 */
export const getPlotMargin = ({
  axisKeys = {},
  axisDomain = {},
  axisType = {},
  axisValues = {},
  tickTotal = {},
  fontSize = 'small',
  yAxisTickFormatter,
}: GetPlotMarginParams): PlotMargin => {
  const {xAxis, yAxis, zAxis} = axisKeys;
  const xIsTime = xAxis && Run.isTimeKeyString(xAxis);
  const yIsTime = yAxis && Run.isTimeKeyString(yAxis);
  const zIsTime = zAxis && Run.isTimeKeyString(zAxis);

  const marginDefaults: PlotMargin = {
    top: 8, // distance between the top of the plot and the bottom of the legend
    right: 14, // need a few px margin here because plot points on the very right side extend beyond the plot into the margin
    bottom: plotFontSizeToPx.axis[fontSize] + axisLabelMargin,
    left: 0,
  };
  const marginOverrides: Partial<PlotMargin> = {};

  if (axisValues?.yAxis != null) {
    marginOverrides.left = getYAxisWidth(axisValues.yAxis, fontSize);
  }

  if (axisDomain?.yAxis) {
    const args = {
      type: axisType.yAxis ?? 'linear',
      domain: axisDomain.yAxis,
    };

    // This is the internal function used by react vis to calculate the axis values
    const scale = getScaleFnFromScaleObject(args);

    const scaleTicks = scale.ticks(tickTotal?.yAxis);
    // Call the formatter with the tick value only (see the doc comment above).
    const tickLabels = scaleTicks.map(tick =>
      yAxisTickFormatter != null
        ? yAxisTickFormatter(tick)
        : yIsTime
        ? Run.formatTimestamp(tick)
        : formatYAxis(tick)
    );

    marginOverrides.left = getYAxisWidth(tickLabels, fontSize);
  }

  // TODO: Check these {x|y|z}IsTime override numbers.
  // For some reason I can no longer select a time value (run.createdAt or run.heartbeatAt) for plot axes,
  // which I think is a regression. We can't test these margins until that's fixed.
  if (xIsTime) {
    marginOverrides.bottom = 55;
    marginOverrides.left = 55;
  }
  if (yIsTime) {
    marginOverrides.left = 100;
  }
  if (zIsTime) {
    marginOverrides.top = 20;
  }

  return {
    ...marginDefaults,
    ...marginOverrides,
  };
};

/**
 * Collects the metric identifiers referenced by the (presumably y-axis) expressions and by
 * the x-axis expression.
 *
 * @param expressions - expressions to scan; may be omitted.
 * @param xExpression - x-axis expression to scan; may be omitted.
 * @returns `xExpressionMetricIdentifiers`, the identifiers from `xExpression`, and
 *   `expressionMetricIdentifiers`, the identifiers of every expression concatenated in order
 *   with duplicates kept. Missing inputs give empty arrays.
 */
export function getMetricIdentifiersFromExpressions(
  expressions?: Expression[],
  xExpression?: Expression
): {
  xExpressionMetricIdentifiers: string[];
  expressionMetricIdentifiers: string[];
} {
  return {
    xExpressionMetricIdentifiers:
      xExpression == null ? [] : metricsInExpression(xExpression),
    expressionMetricIdentifiers: (expressions ?? []).flatMap(expr =>
      metricsInExpression(expr)
    ),
  };
}

/**
 * Merges the metric names needed by a chart into one de-duplicated list.
 *
 * Order is: x-expression identifiers, then expression identifiers, then plain metrics.
 * A repeated name keeps the position of its first occurrence.
 */
export function getAllMetrics(
  metrics: string[],
  expressionMetricIdentifiers: string[],
  xExpressionMetricIdentifiers: string[]
): string[] {
  const ordered = [
    ...xExpressionMetricIdentifiers,
    ...expressionMetricIdentifiers,
    ...metrics,
  ];
  // A Set keeps insertion order and drops later duplicates.
  return Array.from(new Set(ordered));
}
