package it.unimi.dsi.law.fibrations;

/*
 * Copyright (C) 2005-2020 Paolo Boldi and Sebastiano Vigna
 *
 *  This program is free software; you can redistribute it and/or modify it
 *  under the terms of the GNU General Public License as published by the Free
 *  Software Foundation; either version 3 of the License, or (at your option)
 *  any later version.
 *
 *  This program is distributed in the hope that it will be useful, but
 *  WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
 *  or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 *  for more details.
 *
 *  You should have received a copy of the GNU General Public License
 *  along with this program; if not, see <http://www.gnu.org/licenses/>.
 *
 */

import java.io.IOException;
import java.util.Arrays;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import it.unimi.dsi.fastutil.ints.IntArrayList;
import it.unimi.dsi.fastutil.ints.IntArrays;
import it.unimi.dsi.fastutil.ints.IntComparator;
import it.unimi.dsi.fastutil.io.BinIO;
import it.unimi.dsi.webgraph.ImmutableGraph;
import it.unimi.dsi.webgraph.LazyIntIterator;
import it.unimi.dsi.webgraph.NodeIterator;

//RELEASE-STATUS: DIST

/**
 * Static methods to compute the minimum fibration base of a given graph using a partwise algorithm.
 * More precisely, the method
 * {@link #compute(ImmutableGraph, NodeColouringStrategy, ArcColouringStrategy)} starts from a graph
 * (possibly with a colouring on its nodes and/or on its arcs) and returns an array, say
 * <code>a[]</code>, with exactly as many elements as there are nodes in the graph, and with the
 * following properties:
 * <ul>
 * <li><code>a[x]==a[y]</code> iff <code>x</code> and <code>y</code> are in the same fibre of the
 * minimal (coloured, if the graph is such) fibrations;
 * <li>the values contained in <code>a[]</code> range from 0 to <code>k&minus;1</code> where
 * <code>k</code> is the number of nodes in the minimum base;
 * <li>the values in <code>a[]</code> are assigned <em>canonically</em>, that is, if
 * <code>b[]</code> is the array returned by the method on an isomorphic graph (with the same
 * colours, if the graph is coloured) and if <code>f</code> represents the isomorphism, then
 * <code>a[x]=b[f(x)]</code> for every node <code>x</code>.
 * </ul>
 *
 * <h2>Algorithm implementation</h2>
 *
 * <p>
 * The algorithm is a partwise variant of that implemented by
 * {@link it.unimi.dsi.law.fibrations.MinimumBase}. The algorithm keeps a list of touched
 * <em>parts</em>, as opposed to a list of touched <em>nodes</em>, resulting in less memory
 * consumption and, in practice, in faster execution, even if it is not
 * possible for this implementation to prove the good bounds given for
 * {@link it.unimi.dsi.law.fibrations.MinimumBase}. The canonical labels created by the two
 * algorithms, however, are identical.
 *
 * <p>
 * In the following, by <em>partition of a set</em> we mean a subdivision of the set into a number
 * of nonempty disjoint subsets, called <em>parts</em>.
 *
 * <p>
 * The algorithm execution happens in rounds; at the end of each round, a certain partition of the
 * nodes is established. The starting partition is the one determined by node colours, if any, or it
 * is simply the trivial partition with just one part. At every round, the old partition is refined
 * (i.e., some of the parts are further subdivided into subparts). The algorithm stops as soon as no
 * part is actually subdivided at the end of a round: the final partition is the desired one (i.e.,
 * the nodes are partitioned according to the fibres of the minimum fibration).
 *
 * <h3>Basic data structures</h3>
 *
 * <p>
 * <strong>Current partition:</strong> The current partition is stored into two arrays: the first,
 * called <code>part[]</code> simply contains a permutation of the nodes with the property that
 * nodes belonging to the same part appear consecutively; the second, called <code>start[]</code>
 * contains, for each node, the index of <code>part[]</code> at which its part begins. More
 * formally, suppose that <code>part[begin]</code>, <code>part[begin+1]</code>, &hellip;,
 * <code>part[end-1]</code> is one of the parts; then, if <code>x=part[j]</code> for some
 * <code>j</code> between <code>begin</code> and <code>end</code> (i.e., if <code>x</code> is one of
 * the nodes in the part) we have <code>start[x]=begin</code>. In the following, unless otherwise
 * specified, we shall identify a part with its starting index in the array <code>part</code>.
 *
 * <p>
 * <strong>Active parts:</strong> At the beginning of each round, there is a certain set of active
 * parts; their number is stored in <code>numActiveParts</code>, and they are stored in the
 * <code>startActivePart[]</code> array, in arbitrary order.
 *
 * <p>
 * <strong>Touched parts:</strong> During each round, some of the parts are deemed as touched; their
 * number (at the end of the round) is stored in <code>touchedListLength</code>, and they are stored
 * in the <code>touchedList[]</code> array; additionally, there is an array of boolean values,
 * called <code>touched[]</code> such that <code>touched[i]</code> is true iff <code>i</code> is the
 * starting index of a touched part.
 *
 * <h3>First phase: assigning labels to nodes</h3>
 *
 * <p>
 * The final aim of the first phase is to assign to each node <code>x</code> a label that is the
 * list of all nodes <code>y</code> that have an arc towards <code>x</code> and appear in some
 * active parts; such labels (whose length cannot be larger than the indegree of <code>x</code>)
 * will be contained in the array <code>inFrom[x][]</code>, its length being stored in
 * <code>inFill[x]</code>.
 *
 * <p>
 * To obtain this result, the algorithm scans all the nodes in all the active parts, and for each
 * such node <code>y</code> considers all outgoing arcs, writing <code>y</code> in all the labels of
 * the target nodes of such arcs.
 *
 * <p>
 * In this phase, we mark all parts containing at least one target node as <em>touched</em>.
 *
 * <h3>Second phase: refining touched parts</h3>
 *
 * <p>
 * We consider all touched parts in increasing order of their starting index (the touched list is
 * sorted before being scanned). In this phase (some of) these parts will be partitioned into subparts: this amounts to
 * permuting the portion of the array <code>part[]</code> where the part is stored, and changing
 * some of the entries of the array <code>start[]</code> (those that are relative to nodes in the
 * part that is being subpartitioned).
 *
 * <p>
 * Note that at a certain point of this phase we have some parts that have already been
 * subpartitioned (we call them <em>completed</em>), a part that is being considered (we call it
 * <em>current</em>) and some other touched parts that will be considered later on.
 *
 * <p>
 * Suppose that the current part starts at index <code>begin</code> and ends at index
 * <code>end</code> (exclusive). First of all, for each node <code>x</code> in the part, the label
 * <code>inFrom[x][0..inFill[x]-1]</code> is sorted according to the following lexicographic order:
 * <ul>
 * <li>if the colour of the arc (<code>y</code>,<code>x</code>) is smaller than the colour of
 * (<code>y'</code>,<code>x</code>), then <code>y</code> must appear before <code>y'</code>;
 * <li>if the colour of the arc (<code>y</code>,<code>x</code>) is the same as the colour of
 * (<code>y'</code>,<code>x</code>), but the part of <code>y</code> is smaller than the part of
 * <code>y'</code>, then <code>y</code> must appear before <code>y'</code>.
 * </ul>
 *
 * <p>
 * After sorting, the current part is subdivided according to the equivalence relation induced by
 * the previously described lexicographic order. Some care must be taken, though: when comparing the
 * parts of <code>y</code> and <code>y'</code> we are considering the new partitioning for the
 * completed parts, but we use the old partitioning for the current part (i.e., nodes in the current
 * part are considered to have partition number <code>begin</code>).
 */

public final class PartwiseMinimumBase {
	/** Logger used to report per-round progress of {@link #compute(ImmutableGraph, NodeColouringStrategy, ArcColouringStrategy)}. */
	private static final Logger LOGGER = LoggerFactory.getLogger(PartwiseMinimumBase.class);
	/** Whether the internal consistency checks are compiled in; the checks themselves are {@code assert}
	 * statements, so they only have an effect when the JVM runs with assertions enabled. */
	private static final boolean ASSERTS = true;

	/** This class only exposes static methods and cannot be instantiated. */
	private PartwiseMinimumBase() {}

	/**
	 * Compares two source nodes of arcs entering {@link #targetNode}: by arc colour first (when an
	 * arc colouring is available), and then by the starting index of the part each source belongs to.
	 */
	private static final class ColourPartComparator implements IntComparator {
		/** The common target of the arcs being compared; the caller sets it before each sort. */
		public int targetNode;
		/** The arc colouring, or {@code null} if arcs are not coloured. */
		private final ArcColouringStrategy colouringStrategy;
		/** For each node, the starting index of the part it belongs to (shared with the main algorithm). */
		private final int[] start;
		/** Whether an arc colouring was provided. */
		private final boolean hasColours;

		/**
		 * Creates a comparator.
		 *
		 * @param start for each node, the starting index of its part; the array is used by reference.
		 * @param colouringStrategy the arc colouring, or {@code null} for uncoloured arcs.
		 */
		public ColourPartComparator(final int[] start, final ArcColouringStrategy colouringStrategy) {
			this.colouringStrategy = colouringStrategy;
			this.start = start;
			this.hasColours = colouringStrategy != null;
		}

		@Override
		public int compare(final int i, final int j) {
			if (hasColours) {
				// Integer.compare instead of a subtraction: colours are arbitrary ints and their difference may overflow.
				final int byColour = Integer.compare(colouringStrategy.colour(i, targetNode), colouringStrategy.colour(j, targetNode));
				if (byColour != 0) return byColour;
			}
			return Integer.compare(start[i], start[j]);
		}
	}

	/**
	 * Compares two nodes by their current labels, that is, by the lists
	 * <code>inFrom[x][0..inFill[x]-1]</code>: shorter labels come first; labels of equal length are
	 * compared position by position, first by arc colour (when an arc colouring is available) and then
	 * by the starting index of the part of the source node.
	 *
	 * <p>{@link #compare(int, int)} reads <code>start[]</code> as it is, whereas
	 * {@link #equal(int, int)} maps every start index falling in
	 * [{@link #beginCurrentPart}, {@link #endCurrentPart}) to {@link #beginCurrentPart}, so that
	 * all nodes of the part currently being split are treated as still belonging to one part.
	 */
	private static final class NodeLengthLexComparator implements IntComparator {
		/** The first index (inclusive) of the part currently being split; used only by {@link #equal(int, int)}. */
		public int beginCurrentPart;
		/** The last index (exclusive) of the part currently being split; used only by {@link #equal(int, int)}. */
		public int endCurrentPart;

		/** For each node, the number of valid entries in its label. */
		private final int[] inFill;
		/** For each node, its label (a list of source nodes). */
		private final int[][] inFrom;
		/** The arc colouring, or {@code null} if arcs are not coloured. */
		private final ArcColouringStrategy colouringStrategy;
		/** Whether an arc colouring was provided. */
		private final boolean hasColours;
		/** For each node, the starting index of the part it belongs to (shared with the main algorithm). */
		private final int[] start;

		/**
		 * Creates a comparator; all arrays are used by reference.
		 *
		 * @param inFill for each node, the length of its label.
		 * @param inFrom for each node, its label.
		 * @param start for each node, the starting index of its part.
		 * @param colouringStrategy the arc colouring, or {@code null} for uncoloured arcs.
		 */
		private NodeLengthLexComparator(final int[] inFill, final int[][] inFrom, final int[] start, final ArcColouringStrategy colouringStrategy) {
			this.inFill = inFill;
			this.inFrom = inFrom;
			this.start = start;
			this.colouringStrategy = colouringStrategy;
			this.hasColours = colouringStrategy != null;
		}

		@Override
		public int compare(final int x, final int y) {
			final int lx = inFill[x], ly = inFill[y];
			if (lx != ly) return Integer.compare(lx, ly);
			final int[] fromX = inFrom[x], fromY = inFrom[y];
			for(int i = 0; i < lx; i++) {
				if (hasColours) {
					// Integer.compare instead of a subtraction: colours are arbitrary ints and their difference may overflow.
					final int byColour = Integer.compare(colouringStrategy.colour(fromX[i], x), colouringStrategy.colour(fromY[i], y));
					if (byColour != 0) return byColour;
				}
				final int byPart = Integer.compare(start[fromX[i]], start[fromY[i]]);
				if (byPart != 0) return byPart;
			}
			return 0;
		}

		/**
		 * Checks whether two nodes have equal labels, treating every source node whose start index
		 * lies in [{@link #beginCurrentPart}, {@link #endCurrentPart}) as if its start index were
		 * {@link #beginCurrentPart}.
		 *
		 * @param x a node.
		 * @param y another node.
		 * @return true if the labels of <code>x</code> and <code>y</code> have the same length and
		 * agree, position by position, on arc colour (if any) and on the normalised part of the source.
		 */
		public boolean equal(final int x, final int y) {
			final int lx = inFill[x], ly = inFill[y];
			if (lx != ly) return false;
			final int[] fromX = inFrom[x], fromY = inFrom[y];
			for(int i = 0; i < lx; i++) {
				if (hasColours && colouringStrategy.colour(fromX[i], x) != colouringStrategy.colour(fromY[i], y)) return false;
				int startx = start[fromX[i]];
				int starty = start[fromY[i]];
				// The current part may already be partially relabelled: fold it back onto its original start.
				if (startx >= beginCurrentPart && startx < endCurrentPart) startx = beginCurrentPart;
				if (starty >= beginCurrentPart && starty < endCurrentPart) starty = beginCurrentPart;
				if (startx != starty) return false;
			}
			return true;
		}
	}

	/** Returns a labelling of an immutable graph such that two nodes have the same label iff they
	 * are in the same fibre of minimal fibrations.
	 *
	 * <p>Note that the labelling is surjective&mdash;if a node
	 * has label <var>k</var>, there are nodes with label <var>j</var>, for every 0&le;<var>j</var>&le;<var>k</var>.
	 *
	 * <p>For a graph without nodes, an empty array is returned.
	 *
	 * <p>NOTE(review): the body of this method never reads <code>nodeColouring</code>; the
	 * computation always starts from the trivial partition with a single part. Confirm whether
	 * node colours are expected to be honoured here.
	 *
	 * @param g an immutable graph; it is accessed both sequentially (through its node iterator) and randomly.
	 * @param nodeColouring a colouring for the nodes, or {@code null} (currently not consulted, see above).
	 * @param arcColouring a colouring for the arcs, or {@code null}.
	 * @return an array of integers labelling the graph so that two nodes have the same label iff they
	 * are in the same fibre of minimal fibrations.
	 */
	public static int[] compute(final ImmutableGraph g, final NodeColouringStrategy nodeColouring, final ArcColouringStrategy arcColouring) {
		final int n = g.numNodes();
		// An empty graph has an empty labelling; the final renumbering below assumes at least one node.
		if (n == 0) return new int[0];

		// Precomputation of indegrees (for allocating the colour/part lists).
		final int[] inFill = new int[n];
		int[] succ;
		int d;

		final NodeIterator nodeIterator = g.nodeIterator();
		for(int i = 0; i < n; i++) {
			nodeIterator.nextInt();
			d = nodeIterator.outdegree();
			if (d == 0) continue;
			succ = nodeIterator.successorArray();
			while(d-- != 0) inFill[succ[d]]++;
		}

		// Allocation of colour/part lists.
		final int[][] inFrom = new int[n][];
		for(int i = n; i-- != 0;) inFrom[i] = new int[inFill[i]];

		// From now on inFill[x] is the number of valid entries of inFrom[x] in the current round.
		Arrays.fill(inFill, 0);

		/* Parts array: a permutation of nodes. Each entry represents a node in the part.
		 * Parts are contiguous. A sentinel at index n is added to avoid special cases. */
		final int[] part = new int[n + 1];
		for(int i = n + 1; i-- != 0;) part[i] = i;
		/* The start of the part a node (the index) belongs to. A sentinel at index n
		 * is added to avoid special cases. More precisely, start[x] is the index in
		 * part[] where the part containing x starts. */
		final int[] start = new int[n + 1];
		start[n] = Integer.MAX_VALUE;

		// The number of active parts in the current round.
		int numActiveParts = 1;
		// The list of active parts (initially, part 0).
		final int[] startActivePart = new int[n];

		// Whether a part starting at a given index has been touched by the current round.
		final boolean[] touched = new boolean[n];
		// The list of parts that have been touched during the current round.
		final int[] touchedList = new int[n];
		int touchedListLength = 0;

		int s, x, y;
		LazyIntIterator successors;

		final NodeLengthLexComparator nodeLengthLexComparator = new NodeLengthLexComparator(inFill, inFrom, start, arcColouring);
		final ColourPartComparator colourPartComparator = new ColourPartComparator(start, arcColouring);

		// Statistics reported at the beginning of each round.
		int overallMaxLength = n, overallMinLength = n, numSingletons = 0, numParts = 1;

		for(int k = 0; k < n; k++) {
			LOGGER.info("Starting phase " + k + " [parts=" + numParts + ", active=" + numActiveParts + ", size=" + overallMinLength + " -> " + overallMaxLength + ", {*}=" + numSingletons + "]...");
			overallMaxLength = -1;
			overallMinLength = Integer.MAX_VALUE;
			numSingletons = 0;

			// First phase: we run through the active parts and follow outlinks.
			for(int p = numActiveParts; p-- != 0;) {
				s = startActivePart[p];
				// The sentinel at index n guarantees that this scan stops at the end of the last part.
				for(int j = s; start[part[j]] == s; j++) {
					x = part[j];
					d = g.outdegree(x);
					successors = g.successors(x);
					while(d-- != 0) {
						y = successors.nextInt();
						inFrom[y][inFill[y]++] = x;

						// If the partition y belongs to has never been touched, we add it to the touched list.
						if (! touched[start[y]]) {
							touched[start[y]] = true;
							touchedList[touchedListLength++] = start[y];
						}
					}
				}
			}

			LOGGER.info("Touched: " + touchedListLength);

			if (ASSERTS) {
				for(int i = n; i-- != 0;) {
					// At the first stage, the number of active predecessors must be equal to the number of predecessors.
					if (k == 0) assert inFill[i] == inFrom[i].length;
					// Node with a positive number of predecessor must belong to a touched partition.
					assert inFill[i] == 0 || touched[start[i]] :
						"Node " + i + " has inFill " + inFill[i] + " but its partition (" + start[i] + ") has not been touched";
				}
			}

			numActiveParts = 0;

			// Touched parts are processed in increasing order of starting index, independently of discovery order.
			IntArrays.quickSort(touchedList, 0, touchedListLength);

			// Now we do the final pass: we examine each part that has been touched.
			for(int j = 0; j < touchedListLength; j++) {
				final int begin = touchedList[j];
				int end;
				// We find the end of the part; in the meantime, we sort the lists.
				for(end = begin; start[part[end]] == begin; end++) {
					x = part[end];
					colourPartComparator.targetNode = x;
					IntArrays.quickSort(inFrom[x], 0, inFill[x], colourPartComparator);
				}

				if (begin + 1 == end) {
					 // We need no processing for singletons.
					inFill[part[begin]] = 0;
					touched[begin] = false;
					continue;
				}

				// Now we sort the nodes in the current part by comparing the sequence of coloured inlinks.
				IntArrays.quickSort(part, begin, end, nodeLengthLexComparator);

				if (ASSERTS) for(int i = begin; i < end; i++) assert start[part[i]] == begin;

				/* Now we go through the sorted partition, identifying partition borders. We
				 keep track of some statistical data, and of the first largest part. */
				int maxLength = -1, currLength, maxStart = begin, currBegin = begin;

				/* Nodes in the current part are to be considered equal by the comparator. */
				nodeLengthLexComparator.beginCurrentPart = begin;
				nodeLengthLexComparator.endCurrentPart = end;

				for(int l = begin; l < end - 1; l++) {
					x = part[l];
					y = part[l + 1];
					start[x] = currBegin;

					if (nodeLengthLexComparator.equal(x, y)) continue;

					// A border: the subpart [currBegin..l] is complete.
					currLength = l + 1 - currBegin;
					if (currLength > maxLength) {
						maxStart = currBegin;
						maxLength = currLength;
					}
					currBegin = l + 1;
					numParts++;

					// Stats
					if (currLength < overallMinLength) overallMinLength = currLength;
					if (currLength > overallMaxLength) overallMaxLength = currLength;
					if (currLength == 1) numSingletons++;
				}
				// The last subpart is closed by the end of the part.
				start[part[end - 1]] = currBegin;
				currLength = end - currBegin;
				if (currLength > maxLength) maxStart = currBegin;

				// Stats
				if (currLength < overallMinLength) overallMinLength = currLength;
				if (currLength > overallMaxLength) overallMaxLength = currLength;
				if (currLength == 1) numSingletons++;

				if (ASSERTS) for(int i = begin; i < end - 1; i++) assert start[part[i]] <= start[part[i + 1]];

				/* We make a final scan through the parts, adding all
				 new parts (except the first largest part) to the active parts list. */
				for(int l = begin; l < end; l++) {
					x = start[part[l]];
					// Each subpart is detected at its last element; for l == end - 1 the comparison may hit the sentinel.
					if (x != start[part[l + 1]] && x != maxStart) startActivePart[numActiveParts++] = x;
					inFill[part[l]] = 0;
				}

				touched[begin] = false;
			}

			if (ASSERTS) {
				for(int i = n; i-- != 0;) assert ! touched[i];
				for(int i = n; i-- != 0;) assert inFill[i] == 0;
			}
			// No new part became active: the partition is stable.
			if (numActiveParts == 0) break;
			touchedListLength = 0;
		}

		// We renumber parts from 0 and put the result into inFill.
		x = s = 0;
		inFill[part[0]] = 0;
		for(int i = 1; i < n; i++) {
			y = part[i];
			if (start[y] != s) {
				x++;
				s = start[y];
			}
			inFill[y] = x;
		}

		return inFill;
	}

	/**
	 * Command-line entry point; the behaviour depends on the number of arguments.
	 * <ul>
	 * <li>One argument: loads the graph with the given basename and prints its labelling to standard output.
	 * <li>Two arguments: loads the graph and stores the labelling as binary integers in the file named by the second argument.
	 * <li>Three arguments: loads an array of integers from the file named by the second argument,
	 * colours every arc with the entry of that array indexed by the arc source, and stores the
	 * resulting labelling in the file named by the third argument.
	 * </ul>
	 *
	 * @param arg the command-line arguments.
	 * @throws IOException if loading or storing fails.
	 * @throws IllegalArgumentException if the number of arguments is not 1, 2 or 3.
	 */
	public static void main(final String[] arg) throws IOException {
		switch (arg.length) {
		case 1: {
			final int[] labels = compute(ImmutableGraph.load(arg[0]), null, null);
			System.out.println(IntArrayList.wrap(labels));
			break;
		}
		case 2: {
			final int[] labels = compute(ImmutableGraph.load(arg[0]), null, null);
			BinIO.storeInts(labels, arg[1]);
			break;
		}
		case 3: {
			// The colour file is read before the graph is loaded.
			final int[] sourceColour = BinIO.loadInts(arg[1]);
			final ImmutableGraph graph = ImmutableGraph.load(arg[0]);
			final int[] labels = compute(graph, null, (source, target) -> sourceColour[source]);
			BinIO.storeInts(labels, arg[2]);
			break;
		}
		default:
			throw new IllegalArgumentException();
		}
	}
}
