import os
import textwrap as txtwrap

import matplotlib.pyplot as plt
import numpy as np

# Maximum width, in characters, that conversation/turn names are shortened to
# (passed as `width` to textwrap.shorten) before they are plotted.
maxNameLength = 40


def plotLines(infosDir, relDocsPerTurn, setName, plotTitle, xs, ys, legend, axis=None):
  """Draw one line per legend entry, save the figure as a PNG and show it.

  The image is written to
  ``<infosDir>/final_plots/<setName>/<relDocsPerTurn>/<plotTitle>.png``.
  The target directory is created when it does not exist yet.

  Args:
    infosDir: Base output directory (string).
    relDocsPerTurn: Value converted with ``str`` and used as a directory name.
    setName: Directory name placed under ``final_plots``.
    plotTitle: Figure title; also used as the PNG file name.
    xs: Sequence of x-value sequences; ``xs[i]`` is drawn for ``legend[i]``.
    ys: Sequence of y-value sequences; ``ys[i]`` is drawn for ``legend[i]``.
    legend: Sequence of labels; its length decides how many lines are drawn.
    axis: Optional value forwarded to ``plt.axis`` (callers in this file pass
      ``[xmin, xmax, ymin, ymax]``).
  """
  figure = plt.figure(figsize=(16, 8))
  plt.title(plotTitle)
  if axis is not None:
    plt.axis(axis)
  plt.grid(True)
  # one line per legend entry; xs and ys are indexed in step with the legend
  handles = []
  for idx in range(len(legend)):
    handle, = plt.plot(xs[idx], ys[idx])
    handles.append(handle)
  plt.legend(handles, legend)
  plt.tight_layout()
  outPath = infosDir + "/final_plots/" + setName + "/" + str(relDocsPerTurn) + "/" + plotTitle + ".png"
  # savefig does not create missing directories, so make sure the target exists
  os.makedirs(os.path.dirname(outPath), exist_ok=True)
  plt.savefig(outPath)
  plt.show()
  # pyplot keeps every figure alive until it is closed explicitly, and this
  # function is called once per conversation. Outside interactive mode show()
  # has already finished with the figure, so release it here; in interactive
  # mode the window is left open for the user.
  if not plt.isinteractive():
    plt.close(figure)


def plotPoints(infosDir, relDocsPerTurn, setName, plotTitle, xs, ys, legend, axis=None):
  """Draw one scatter series per legend entry, save the figure and show it.

  The image is written to
  ``<infosDir>/final_plots/<setName>/<relDocsPerTurn>/<plotTitle>.png``.
  The target directory is created when it does not exist yet.

  Args:
    infosDir: Base output directory (string).
    relDocsPerTurn: Value converted with ``str`` and used as a directory name.
    setName: Directory name placed under ``final_plots``.
    plotTitle: Figure title; also used as the PNG file name.
    xs: Sequence of x-value sequences; ``xs[i]`` is drawn for ``legend[i]``.
    ys: Sequence of y-value sequences; ``ys[i]`` is drawn for ``legend[i]``.
    legend: Sequence of labels; its length decides how many series are drawn.
    axis: Optional value forwarded to ``plt.axis``.
  """
  # markers are reused in pairs and cycle when there are more series than symbols
  markersSymbols = ["o", "o", "^", "^", "v", "v", "s", "s", "P", "P"]
  figure = plt.figure(figsize=(16, 8))
  plt.title(plotTitle)
  if axis is not None:
    plt.axis(axis)
  plt.grid(True)
  # one scatter series per legend entry; xs and ys are indexed in step with it
  for idx, label in enumerate(legend):
    plt.scatter(xs[idx], ys[idx], label=label, marker=markersSymbols[idx % len(markersSymbols)])
  # the legend is built from the labels given to scatter above
  plt.legend()
  plt.tight_layout()
  outPath = infosDir + "/final_plots/" + setName + "/" + str(relDocsPerTurn) + "/" + plotTitle + ".png"
  # savefig does not create missing directories, so make sure the target exists
  os.makedirs(os.path.dirname(outPath), exist_ok=True)
  plt.savefig(outPath)
  plt.show()
  # pyplot keeps every figure alive until it is closed explicitly, and this
  # function is called once per conversation. Outside interactive mode show()
  # has already finished with the figure, so release it here; in interactive
  # mode the window is left open for the user.
  if not plt.isinteractive():
    plt.close(figure)


def plotMetricAlongConversation(infosDir, relDocsPerTurn, setName, metricName, matrices, modelsNames, convNumbers, preName=""):
  """Plot a metric turn by turn: first the mean over conversations, then each conversation.

  ``matrices[m]`` is indexed as ``[conversation][turn]`` for model ``m``. The
  x axis of every plot is ``1 .. len(matrices[0][0])``.

  Args:
    infosDir, relDocsPerTurn, setName: Forwarded unchanged to ``plotLines``.
    metricName: Metric name used in the plot titles.
    matrices: One score matrix per model.
    modelsNames: Legend labels, one per model.
    convNumbers: ``convNumbers[i]`` is the number shown in the title of the
      plot for conversation row ``i``.
    preName: Optional prefix for every plot title.
  """
  titleStem = preName + "Per turn score for " + metricName
  # NaN-ignoring average over the conversation axis, one curve per model
  meanPerTurn = [np.nanmean(modelMatrix, axis=0).tolist() for modelMatrix in matrices]
  # every model shares the same 1-based turn axis, sized from the first matrix
  turnAxes = [range(1, len(matrices[0][0]) + 1) for _ in matrices]
  plotLines(infosDir, relDocsPerTurn, setName, titleStem, turnAxes, meanPerTurn, modelsNames)
  # one additional plot per conversation row
  for convIdx in range(len(matrices[0])):
    convScores = [modelMatrix[convIdx] for modelMatrix in matrices]
    plotLines(infosDir, relDocsPerTurn, setName, titleStem + " on conversation " + str(convNumbers[convIdx]), turnAxes, convScores, modelsNames)


def plotMetricEachConversation(infosDir, relDocsPerTurn, setName, metricName, matrices, modelsNames, convNumbers, convNames, preName=""):
  """Scatter a metric against conversation names, then against turn names per conversation.

  ``matrices[m]`` is indexed as ``[conversation][turn]`` for model ``m``.
  ``convNames`` is indexed with ``[:, 0]`` for the conversation names and
  ``[i, 1:]`` for the turn names of conversation ``i``, so it must support
  two-dimensional indexing (presumably a NumPy array of strings; verify
  against the caller).

  Args:
    infosDir, relDocsPerTurn, setName: Forwarded unchanged to ``plotPoints``.
    metricName: Metric name used in the plot titles.
    matrices: One score matrix per model.
    modelsNames: Legend labels, one per model.
    convNumbers: ``convNumbers[i]`` is the number shown in the title of the
      plot for conversation row ``i``.
    convNames: Conversation and turn names (see above).
    preName: Optional prefix for every plot title.
  """
  def _shortened(labels):
    # textwrap.shorten collapses whitespace and truncates to maxNameLength
    return [txtwrap.shorten(label, width=maxNameLength, placeholder="...") for label in labels]

  # overall comparison: scores go on x, shortened conversation names on y
  convLabelsPerModel = [_shortened(convNames[:, 0]) for _ in matrices]
  # NaN-ignoring average over the turn axis, one value per conversation
  meanPerConv = [np.nanmean(modelMatrix, axis=1) for modelMatrix in matrices]
  plotPoints(infosDir, relDocsPerTurn, setName, preName + metricName + " comparison", meanPerConv, convLabelsPerModel, modelsNames)
  # one additional plot per conversation: scores on x, shortened turn names on y
  for convIdx in range(len(matrices[0])):
    turnScores = [modelMatrix[convIdx] for modelMatrix in matrices]
    turnLabelsPerModel = [_shortened(convNames[convIdx, 1:]) for _ in matrices]
    plotPoints(infosDir, relDocsPerTurn, setName, preName + metricName + " comparison on conversation " + str(convNumbers[convIdx]), turnScores, turnLabelsPerModel, modelsNames)


def plotPrecisionRecall(infosDir, relDocsPerTurn, setName, recall_matrices, precision_matrices, modelsNames, convNumbers, preName=""):
  """Plot precision against recall for every conversation and for the mean over conversations.

  For each model, precision is linearly interpolated (``np.interp``) at the
  recall levels 0.0, 0.1, ..., 1.0, and all models are drawn on one plot with
  both axes fixed to [0, 1].

  NOTE(review): ``np.interp`` expects its sample x-coordinates (here the
  recall values) to be increasing; confirm the caller guarantees this.

  Args:
    infosDir, relDocsPerTurn, setName: Forwarded unchanged to ``plotLines``.
    recall_matrices: ``recall_matrices[m][c]`` holds the recall values of
      model ``m`` on conversation row ``c``.
    precision_matrices: Same layout as ``recall_matrices``, holding the
      matching precision values.
    modelsNames: Legend labels, one per model.
    convNumbers: ``convNumbers[c]`` is the number shown in the title of the
      plot for conversation row ``c``.
    preName: Optional prefix for every plot title.
  """
  # recall levels 0.0, 0.1, ..., 1.0 at which precision is sampled
  recallGrid = [step / 10.0 for step in range(11)]
  unitSquare = [0.0, 1.0, 0.0, 1.0]
  numModels = len(recall_matrices)
  # one plot per conversation
  for convIdx in range(len(recall_matrices[0])):
    curves = [
      np.interp(recallGrid, recall_matrices[m][convIdx], precision_matrices[m][convIdx])
      for m in range(numModels)
    ]
    plotLines(infosDir, relDocsPerTurn, setName, preName + "Precision-Recall on conversation " + str(convNumbers[convIdx]), [recallGrid] * numModels, curves, modelsNames, unitSquare)
  # one plot for the curves averaged (ignoring NaN) over conversations
  meanCurves = []
  for m in range(numModels):
    meanPrecision = np.nanmean(precision_matrices[m], axis=0)
    meanRecall = np.nanmean(recall_matrices[m], axis=0)
    meanCurves.append(np.interp(recallGrid, meanRecall, meanPrecision))
  plotLines(infosDir, relDocsPerTurn, setName, preName + "Precision-Recall", [recallGrid] * numModels, meanCurves, modelsNames, unitSquare)