package it.unimi.dsi.law.big.stat;

import static it.unimi.dsi.fastutil.BigArrays.get;

/*
 *  Copyright (C) 2011-2020 Paolo Boldi, Massimo Santini and Sebastiano Vigna
 *
 *  This program is free software; you can redistribute it and/or modify it
 *  under the terms of the GNU General Public License as published by the Free
 *  Software Foundation; either version 3 of the License, or (at your option)
 *  any later version.
 *
 *  This program is distributed in the hope that it will be useful, but
 *  WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
 *  or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 *  for more details.
 *
 *  You should have received a copy of the GNU General Public License
 *  along with this program; if not, see <http://www.gnu.org/licenses/>.
 *
 */

import static it.unimi.dsi.fastutil.BigArrays.length;
import static it.unimi.dsi.fastutil.BigArrays.set;
import static it.unimi.dsi.fastutil.BigArrays.wrap;
import static it.unimi.dsi.law.big.stat.WeightedTau.HYPERBOLIC_WEIGHER;
import static it.unimi.dsi.law.big.stat.WeightedTau.LOGARITHMIC_WEIGHER;
import static it.unimi.dsi.law.big.stat.WeightedTau.QUADRATIC_WEIGHER;
import static org.junit.Assert.assertEquals;

import java.io.File;
import java.io.IOException;

import org.junit.Test;

import it.unimi.dsi.Util;
import it.unimi.dsi.fastutil.BigArrays;
import it.unimi.dsi.fastutil.doubles.DoubleBigArrays;
import it.unimi.dsi.fastutil.io.BinIO;
import it.unimi.dsi.fastutil.io.TextIO;
import it.unimi.dsi.fastutil.longs.Long2DoubleFunction;
import it.unimi.dsi.fastutil.longs.LongBigArrays;
import it.unimi.dsi.util.XoRoShiRo128PlusRandom;

//RELEASE-STATUS: DIST

/**
 * Tests for the big-array version of {@link WeightedTau}.
 *
 * <p>Most tests compare the value returned by {@link WeightedTau} on big arrays with the value
 * returned by {@code it.unimi.dsi.law.stat.WeightedTauTest.compute()} on the same data copied into
 * standard arrays (see {@link #compute(Long2DoubleFunction, boolean, double[][], double[][], long[][])}).
 */
public class WeightedTauTest {
	/** A weigher returning 1 for every key. */
	private static final Long2DoubleFunction CONSTANT_WEIGHER = new WeightedTau.AbstractWeigher() {
		private static final long serialVersionUID = 1L;
		@Override
		public double get(final long key) {
			return 1;
		}
	};
	/** The weighers every test iterates over. */
	private static final Long2DoubleFunction[] WEIGHER = new Long2DoubleFunction[] { CONSTANT_WEIGHER, HYPERBOLIC_WEIGHER, LOGARITHMIC_WEIGHER, QUADRATIC_WEIGHER };
	/** Both settings of the {@code multiplicative} flag. */
	private static final boolean[] MULTIPLICATIVE = { false, true };

	private final double[][] ordered = wrap(new double[] { 0.0, 1.0, 2.0, 3.0, 4.0 });
	private final double[][] reverse = wrap(new double[] { 4.0, 3.0, 2.0, 1.0, 0.0 });
	private final double[][] reverseButOne = wrap(new double[] { 10.0, 9.0, 7.0, 8.0, 6.0 });
	private final double[][] allOnes = wrap(new double[] { 1.0, 1.0, 1.0 });
	private final double[][] allZeroes = wrap(new double[] { 0.0, 0.0, 0.0 });

	/**
	 * Computes the expected value by copying the big arrays into standard arrays and delegating to
	 * {@code it.unimi.dsi.law.stat.WeightedTauTest.compute()}.
	 *
	 * @param weigher the weigher; it is adapted to the delegate through a lambda.
	 * @param multiplicative passed unchanged to the delegate.
	 * @param bv0 the first score vector, as a big array.
	 * @param bv1 the second score vector, as a big array.
	 * @param brank the rank, as a big array, or {@code null}; in the latter case {@code null} is passed to the delegate.
	 * @return the value returned by the delegate.
	 * @throws ArithmeticException if a length or a rank does not fit in an {@code int}.
	 */
	public static double compute(final Long2DoubleFunction weigher, final boolean multiplicative, final double[][] bv0, final double[][] bv1, final long[][] brank) {
		// toIntExact fails loudly instead of silently truncating data that cannot fit a standard array.
		final double[] v0 = new double[Math.toIntExact(length(bv0))];
		final double[] v1 = new double[Math.toIntExact(length(bv1))];
		BigArrays.copyFromBig(bv0, 0, v0, 0, v0.length);
		BigArrays.copyFromBig(bv1, 0, v1, 0, v1.length);

		final int[] rank;
		if (brank == null) rank = null;
		else {
			rank = new int[Math.toIntExact(length(brank))];
			for (int i = 0; i < rank.length; i++) rank[i] = Math.toIntExact(get(brank, i));
		}

		return it.unimi.dsi.law.stat.WeightedTauTest.compute(key -> weigher.get(key), multiplicative, v0, v1, rank);
	}

	/**
	 * Asserts that {@link WeightedTau} agrees with
	 * {@link #compute(Long2DoubleFunction, boolean, double[][], double[][], long[][])} on the given
	 * vectors, in both argument orders.
	 *
	 * @param weigher the weigher.
	 * @param multiplicative the {@code multiplicative} flag.
	 * @param v0 the first score vector.
	 * @param v1 the second score vector.
	 * @param rank the rank, or {@code null}.
	 * @param delta the tolerance used in the comparison.
	 */
	private static void assertMatchesReference(final Long2DoubleFunction weigher, final boolean multiplicative, final double[][] v0, final double[][] v1, final long[][] rank, final double delta) {
		final WeightedTau weightedTau = new WeightedTau(weigher, multiplicative);
		assertEquals(compute(weigher, multiplicative, v0, v1, rank), weightedTau.compute(v0, v1, rank), delta);
		assertEquals(compute(weigher, multiplicative, v1, v0, rank), weightedTau.compute(v1, v0, rank), delta);
	}

	/** Compares a vector with itself. */
	@Test
	public void testComputeOrdered() {
		for (final boolean multiplicative : MULTIPLICATIVE) {
			for (final Long2DoubleFunction weigher : WEIGHER) {
				final WeightedTau weightedTau = new WeightedTau(weigher, multiplicative);
				final double expected = compute(weigher, multiplicative, ordered, ordered, null); // (10.0 - 0.0) / 10.0;
				assertEquals(expected, weightedTau.compute(ordered, ordered, null), 1E-15);
			}
		}
	}

	/** Compares a vector with its reverse, in both argument orders. */
	@Test
	public void testComputeWithReverse() {
		for (final boolean multiplicative : MULTIPLICATIVE) {
			for (final Long2DoubleFunction weigher : WEIGHER) {
				// (0 - 10.0) / 10.0;
				assertMatchesReference(weigher, multiplicative, ordered, reverse, null, 1E-15);
			}
		}
	}

	/** Compares a vector with a reversed vector in which one adjacent pair is swapped, in both argument orders. */
	@Test
	public void testComputeWithReverseButOne() {
		for (final boolean multiplicative : MULTIPLICATIVE) {
			for (final Long2DoubleFunction weigher : WEIGHER) {
				// (1.0 - 9.0) / 10.0;
				assertMatchesReference(weigher, multiplicative, ordered, reverseButOne, null, 1E-15);
			}
		}
	}

	/** Compares two short vectors, each containing a tie. */
	@Test
	public void testTies() {
		for (final boolean multiplicative : MULTIPLICATIVE) {
			for (final Long2DoubleFunction weigher : WEIGHER) {
				final double[][] v0 = wrap(new double[] { 0.1, 0.1, 0.2 });
				final double[][] v1 = wrap(new double[] { 0.4, 0.3, 0.3 });
				assertMatchesReference(weigher, multiplicative, v0, v1, null, 1E-15);
			}
		}
	}

	/** Compares two vectors of 1000 pseudorandom doubles. */
	@Test
	public void testRandom() {
		for (final boolean multiplicative : MULTIPLICATIVE) {
			for (final Long2DoubleFunction weigher : WEIGHER) {
				final XoRoShiRo128PlusRandom random = new XoRoShiRo128PlusRandom(0);
				final double[][] v0 = DoubleBigArrays.newBigArray(1000);
				final double[][] v1 = DoubleBigArrays.newBigArray(1000);

				// A big array's own .length is its number of segments: length() gives the number of elements.
				for (long i = length(v0); i-- != 0;) {
					set(v0, i, random.nextDouble());
					set(v1, i, random.nextDouble());
				}

				assertMatchesReference(weigher, multiplicative, v0, v1, null, 1E-10);
			}
		}
	}

	/** Compares two vectors of 1000 pseudorandom integers in [0..10), so that ties are frequent. */
	@Test
	public void testRandomWithTies() {
		for (final boolean multiplicative : MULTIPLICATIVE) {
			for (final Long2DoubleFunction weigher : WEIGHER) {
				final XoRoShiRo128PlusRandom random = new XoRoShiRo128PlusRandom(0);
				final double[][] v0 = DoubleBigArrays.newBigArray(1000);
				final double[][] v1 = DoubleBigArrays.newBigArray(1000);

				// A big array's own .length is its number of segments: length() gives the number of elements.
				for (long i = length(v0); i-- != 0;) {
					set(v0, i, random.nextInt(10));
					set(v1, i, random.nextInt(10));
				}

				assertMatchesReference(weigher, multiplicative, v0, v1, null, 1E-10);
			}
		}
	}

	/** Same data as {@link #testRandomWithTies()}, but with an explicit rank given by a shuffled identity. */
	@Test
	public void testRandomWithTiesAndRank() {
		for (final boolean multiplicative : MULTIPLICATIVE) {
			for (final Long2DoubleFunction weigher : WEIGHER) {
				final XoRoShiRo128PlusRandom random = new XoRoShiRo128PlusRandom(0);
				final double[][] v0 = DoubleBigArrays.newBigArray(1000);
				final double[][] v1 = DoubleBigArrays.newBigArray(1000);

				// A big array's own .length is its number of segments: length() gives the number of elements.
				for (long i = length(v0); i-- != 0;) {
					set(v0, i, random.nextInt(10));
					set(v1, i, random.nextInt(10));
				}

				final long[][] rank = Util.identity(length(v0));
				LongBigArrays.shuffle(rank, random);

				assertMatchesReference(weigher, multiplicative, v0, v1, rank, 1E-10);
			}
		}
	}

	/** Checks that two constant vectors yield 1. */
	@Test
	public void testAllTies() {
		for (final Long2DoubleFunction weigher : WEIGHER) {
			assertEquals(1.0, new WeightedTau(weigher).compute(allOnes, allZeroes, null), 1E-15);
		}
	}

	/**
	 * Checks the file-based methods: a sequence stored in a file is compared with its reverse
	 * (expecting -1) and with itself (expecting 1), for several combinations of storage types.
	 *
	 * @throws IOException if a temporary file cannot be created or written.
	 */
	@Test
	public void testInputType() throws IOException {
		final File a = File.createTempFile(WeightedTauTest.class.getSimpleName(), "a");
		a.deleteOnExit();
		final File b = File.createTempFile(WeightedTauTest.class.getSimpleName(), "b");
		b.deleteOnExit();
		try {
			for (final boolean reverse : new boolean[] { true, false }) {
				for (final Long2DoubleFunction weigher : WEIGHER) {
					final WeightedTau weightedTau = new WeightedTau(weigher);
					final String fa = a.toString();
					final String fb = b.toString();

					// Integers in both files.
					BinIO.storeInts(new int[] { 0, 1, 2, 3, 4 }, a);
					BinIO.storeInts(new int[] { 4, 3, 2, 1, 0 }, b);
					assertEquals(-1, weightedTau.computeInts(fa, fb, reverse), 1E-15);
					// TODO: main test
					assertEquals(-1, weightedTau.compute(fa, Integer.class, fb, Integer.class, reverse), 1E-15);
					BinIO.storeInts(new int[] { 0, 1, 2, 3, 4 }, b);
					assertEquals(1, weightedTau.computeInts(fa, fb, reverse), 1E-15);
					assertEquals(1, weightedTau.compute(fa, Integer.class, fb, Integer.class, reverse), 1E-15);

					// Longs in both files.
					BinIO.storeLongs(new long[] { 0, 1, 2, 3, 4 }, a);
					BinIO.storeLongs(new long[] { 4, 3, 2, 1, 0 }, b);
					assertEquals(-1, weightedTau.computeLongs(fa, fb, reverse), 1E-15);
					assertEquals(-1, weightedTau.compute(fa, Long.class, fb, Long.class, reverse), 1E-15);
					BinIO.storeLongs(new long[] { 0, 1, 2, 3, 4 }, b);
					assertEquals(1, weightedTau.computeLongs(fa, fb, reverse), 1E-15);
					assertEquals(1, weightedTau.compute(fa, Long.class, fb, Long.class, reverse), 1E-15);

					// Floats in both files.
					BinIO.storeFloats(new float[] { 0, 1, 2, 3, 4 }, a);
					BinIO.storeFloats(new float[] { 4, 3, 2, 1, 0 }, b);
					assertEquals(-1, weightedTau.computeFloats(fa, fb, reverse), 1E-15);
					assertEquals(-1, weightedTau.compute(fa, Float.class, fb, Float.class, reverse), 1E-15);
					BinIO.storeFloats(new float[] { 0, 1, 2, 3, 4 }, b);
					assertEquals(1, weightedTau.computeFloats(fa, fb, reverse), 1E-15);
					assertEquals(1, weightedTau.compute(fa, Float.class, fb, Float.class, reverse), 1E-15);

					// Doubles in both files.
					BinIO.storeDoubles(new double[] { 0, 1, 2, 3, 4 }, a);
					BinIO.storeDoubles(new double[] { 4, 3, 2, 1, 0 }, b);
					assertEquals(-1, weightedTau.computeDoubles(fa, fb, reverse), 1E-15);
					assertEquals(-1, weightedTau.compute(fa, Double.class, fb, Double.class, reverse), 1E-15);
					BinIO.storeDoubles(new double[] { 0, 1, 2, 3, 4 }, b);
					assertEquals(1, weightedTau.computeDoubles(fa, fb, reverse), 1E-15);
					assertEquals(1, weightedTau.compute(fa, Double.class, fb, Double.class, reverse), 1E-15);

					// Integers against doubles.
					BinIO.storeInts(new int[] { 0, 1, 2, 3, 4 }, a);
					BinIO.storeDoubles(new double[] { 4, 3, 2, 1, 0 }, b);
					assertEquals(-1, weightedTau.compute(fa, Integer.class, fb, Double.class, reverse), 1E-15);
					BinIO.storeDoubles(new double[] { 0, 1, 2, 3, 4 }, b);
					assertEquals(1, weightedTau.compute(fa, Integer.class, fb, Double.class, reverse), 1E-15);

					// Doubles against longs.
					BinIO.storeDoubles(new double[] { 0, 1, 2, 3, 4 }, a);
					BinIO.storeLongs(new long[] { 4, 3, 2, 1, 0 }, b);
					assertEquals(-1, weightedTau.compute(fa, Double.class, fb, Long.class, reverse), 1E-15);
					BinIO.storeLongs(new long[] { 0, 1, 2, 3, 4 }, b);
					assertEquals(1, weightedTau.compute(fa, Double.class, fb, Long.class, reverse), 1E-15);

					// Doubles stored in text form against longs.
					TextIO.storeDoubles(new double[] { 0, 1, 2, 3, 4 }, a);
					BinIO.storeLongs(new long[] { 4, 3, 2, 1, 0 }, b);
					assertEquals(-1, weightedTau.compute(fa, String.class, fb, Long.class, reverse), 1E-15);
					BinIO.storeLongs(new long[] { 0, 1, 2, 3, 4 }, b);
					assertEquals(1, weightedTau.compute(fa, String.class, fb, Long.class, reverse), 1E-15);
				}
			}
		}
		finally {
			// Remove the temporary files even when an assertion fails.
			a.delete();
			b.delete();
		}
	}
}
