package br.ufrgs.inf.prosoft.tfcache;

import br.ufrgs.inf.prosoft.tfcache.configuration.Configuration;
import br.ufrgs.inf.prosoft.tfcache.metadata.Method;
import br.ufrgs.inf.prosoft.tfcache.metadata.Occurrence;

import java.util.*;
import java.util.stream.Collectors;
import java.util.stream.LongStream;
import java.util.stream.Stream;

/**
 * @author romulo
 */
public class Simulator {

  private final Pareto pareto;

  public Simulator() {
    this.pareto = new Pareto();
  }

  public static void simulate(List<Occurrence> occurrences, Pareto pareto) {
    Pareto computeIfAbsent = StorageManager.computeIfAbsent(occurrences, () -> {
      if ("exhaustive".equals(Configuration.getKernel())) simulate(occurrences, generateAllTTLs(occurrences), pareto);
      else if ("optimised".equals(Configuration.getKernel())) simulate(occurrences, generateTTLsOfInterest(occurrences), pareto);
      else testKernels(occurrences);
      return pareto;
    });
    computeIfAbsent.values().forEach(pareto::addIfPareto);
  }

  public static void simulate(List<Occurrence> occurrences, Stream<Long> ttls, Pareto pareto) {
    ttls.forEach(actualTTL -> simulate(occurrences.stream(), actualTTL, pareto));
  }

  public static void simulate(Stream<Occurrence> occurrences, long ttl, Pareto pareto) {
    Map<String, Long> inputHasCachedTime = new HashMap<>();
    Map<String, Object> inputHasOutput = new HashMap<>();

    long blindedSavedTime = 0;
    long realSavedTime = 0;
    long hits = 0;
    long computationTime = 0;
    long timeInCache = 0;
    long stales = 0;

    Iterator<Occurrence> iterator = occurrences.iterator();
    while (iterator.hasNext()) {
      Occurrence occurrence = iterator.next();
      long adjustedStartTime = occurrence.getStartTime() - blindedSavedTime;
      long adjustedEndTime = occurrence.getEndTime() - blindedSavedTime;

      if (occurrence.getExecutionTime() < 0) throw new RuntimeException("executionTime cannot be under zero");
      if (adjustedEndTime < adjustedStartTime) throw new RuntimeException("adjustedEndTime should not be lesser than adjustedStartTime");

      if (inputHasCachedTime.containsKey(occurrence.getParametersSerialised()) &&
        adjustedStartTime - inputHasCachedTime.get(occurrence.getParametersSerialised()) > ttl) {
        inputHasCachedTime.remove(occurrence.getParametersSerialised());
      }
      if (inputHasCachedTime.containsKey(occurrence.getParametersSerialised())) {
        if (Configuration.getStaleness().equals("shrink")) {
          if (Objects.deepEquals(inputHasOutput.get(occurrence.getParametersSerialised()), occurrence.getReturnValue())) {
            realSavedTime += occurrence.getExecutionTime();
          } else {
            stales++;
          }
        }
        blindedSavedTime += occurrence.getExecutionTime();
        hits++;
      } else {
        inputHasCachedTime.put(occurrence.getParametersSerialised(), adjustedEndTime);
        if (Configuration.getStaleness().equals("shrink")) inputHasOutput.put(occurrence.getParametersSerialised(), occurrence.getReturnValue());
        computationTime += occurrence.getExecutionTime();
        timeInCache += ttl;
      }
    }
    Metrics metrics = new Metrics(ttl, hits, timeInCache, computationTime, stales, Configuration.getStaleness().equals("shrink") ? realSavedTime : blindedSavedTime);
    pareto.addIfPareto(metrics);
  }

  public static Stream<Long> generateTTLsOfInterest(List<Occurrence> occurrences) {
    List<Long> windows = new ArrayList<>();
    for (int hits = 1; hits < occurrences.size(); hits++) {
      long window = occurrences.get(hits).getStartTime() - occurrences.get(hits - 1).getEndTime();
      if (window > 0) windows.add(window);
    }

    Set<Long> ttlsOfInterest = new HashSet<>(windows);
    for (int hits = 2; hits <= windows.size(); hits++) {
      for (int shift = 0; shift <= windows.size() - hits; shift++) {
        long ttl = 0;
        for (int k = shift; k < shift + hits; k++) ttl += windows.get(k);
        ttlsOfInterest.add(ttl);
      }
    }
    return ttlsOfInterest.stream().parallel();
  }

  public static Stream<Long> generateAllTTLs(List<Occurrence> occurrences) {
    long maxTTL = occurrences.get(occurrences.size() - 1).getStartTime() - occurrences.get(0).getEndTime();
    long minTTL = Long.MAX_VALUE;
    for (int i = 0; i < occurrences.size() - 1; i++) {
      long ttl = occurrences.get(i + 1).getStartTime() - occurrences.get(i).getEndTime();
      if (ttl > 0 && ttl < minTTL) minTTL = ttl;
    }
    return LongStream.rangeClosed(minTTL, maxTTL).boxed().parallel();
  }

  private static void testKernels(List<Occurrence> occurrences) {
    Pareto optimisedPareto = new Pareto();
    Pareto exhaustivePareto = new Pareto();

    Map<String, List<Occurrence>> inputHasOccurrences = Method.groupByInput(occurrences);
    Set<Long> ttlsOfInterest = inputHasOccurrences.values().stream()
      .map(Simulator::generateTTLsOfInterest)
      .reduce(Stream::concat)
      .orElse(Stream.empty())
      .collect(Collectors.toSet());

    simulate(occurrences, ttlsOfInterest.stream(), optimisedPareto);
    simulate(occurrences, generateAllTTLs(occurrences), exhaustivePareto);

    List<Long> missingTTLs = exhaustivePareto.values().stream().map(Metrics::getTtl)
      .filter(ttl -> !ttlsOfInterest.contains(ttl))
      .sorted()
      .collect(Collectors.toList());
    if (!missingTTLs.isEmpty()) {
      System.out.println("=== " + Configuration.getInput() + " ===");
      System.out.println("\tMissing ttls: " + missingTTLs);
    }

    Metrics maxExhaustiveMetrics = exhaustivePareto.getBestMetrics();
    Metrics maxOptimisedMetrics = optimisedPareto.getBestMetrics();
    if (maxExhaustiveMetrics.getTtl() != maxOptimisedMetrics.getTtl()) {
      System.out.println("=== " + Configuration.getInput() + " ===");
      System.out.println("\tDIFFERENT BEST METRICS");
      System.out.println("\tOptimised: " + maxOptimisedMetrics);
      System.out.println("\tExhaustive: " + maxExhaustiveMetrics);
      if (maxExhaustiveMetrics.compareTo(maxOptimisedMetrics) < 0) System.out.println("\tOptimised won");
      else if (maxExhaustiveMetrics.compareTo(maxOptimisedMetrics) > 0) System.out.println("\tExhaustive won");
      else System.out.println("\tEquivalent recommendation");
    }
  }

  public Pareto getPareto() {
    return pareto;
  }

  public void simulate(List<Occurrence> occurrences) {
    simulate(occurrences, this.pareto);
  }

  public void simulate(Stream<Occurrence> occurrences, long ttl) {
    simulate(occurrences, ttl, this.pareto);
  }

}
