package jprofilegrid.calculations;

import java.util.Hashtable;
import java.util.Iterator;
import java.util.List;
import java.util.Vector;

import jprofilegrid.constants.AlignmentConstants;
import jprofilegrid.constants.AminoAcidConstants;
import jprofilegrid.model.AnalysisOptions;
import jprofilegrid.model.UnknownSymbol;

/**
 * Result of a sliding-window similarity analysis of a multiple sequence alignment.
 *
 * <p>An instance holds, for every window width from 1 up to the configured window length,
 * one corrected similarity value per alignment column, together with the start/end column
 * indices of the regions whose value at the configured window length lies above
 * ("motiff" regions) or below ("variable" regions) the threshold.
 *
 * <p>Instances are created only by {@link #analyze}, which attaches the result to the
 * analysed alignment.
 */
public class MultipleSequenceAnalysis
{
	private final Vector<Vector<Double>> correctedWindowSizes;
	private final int[] motiffEndpointIndices;
	private final int[] variableEndpointIndices;
	private final int windowLength;
	private final Vector<UnknownSymbol> unkSymbolInSeq;
	private final String weightedString;

	/**
	 * Stores the already computed analysis data.
	 *
	 * @param nUncorrectedWindowSizes uncorrected window values; accepted for signature
	 *        compatibility but not retained
	 * @param nWindowLength window width that was used to locate the regions
	 * @param nCorrectedWindowSizes corrected values, indexed by window width, then by column
	 * @param motiffEndpoints alternating start/end indices of the regions above the threshold
	 * @param nVariableEndpointIndices alternating start/end indices of the regions in between
	 * @param nUnknownSymbolsInSequences symbols found in the alignment that the alignment
	 *        constants do not define
	 * @param nWeightedString value of {@code AnalysisOptions.weightedString}, passed through
	 */
	private MultipleSequenceAnalysis( Vector<Vector<Double>>  nUncorrectedWindowSizes, int nWindowLength,
			Vector<Vector<Double>> nCorrectedWindowSizes, int[] motiffEndpoints, int[] nVariableEndpointIndices,
			Vector<UnknownSymbol> nUnknownSymbolsInSequences, String nWeightedString)
	{
		windowLength = nWindowLength;
		correctedWindowSizes = nCorrectedWindowSizes;
		motiffEndpointIndices = motiffEndpoints;
		variableEndpointIndices = nVariableEndpointIndices;
		unkSymbolInSeq = nUnknownSymbolsInSequences;
		weightedString = nWeightedString;
	}

	/**
	 * Returns the per-window values.
	 *
	 * <p>NOTE(review): this returns the same corrected data as {@link #getCorrected()};
	 * the uncorrected values are not kept by this class. Confirm that callers expect this.
	 *
	 * @return corrected values, indexed by window width (index 0 is an empty placeholder),
	 *         then by column
	 */
	public Vector<Vector<Double>> getWindowSizes()
	{
		return correctedWindowSizes;
	}

	/**
	 * @return the window width that was used to locate motiff and variable regions
	 */
	public int getWindowLength()
	{
		return windowLength;
	}

	/**
	 * @return corrected values, indexed by window width (index 0 is an empty placeholder),
	 *         then by column
	 */
	public Vector<Vector<Double>> getCorrected()
	{
		return correctedWindowSizes;
	}

	/**
	 * @return alternating start/end column indices of the regions above the threshold
	 *         (the internal array, not a copy)
	 */
	public int[] getMotiffEndpointIndices()
	{
		return motiffEndpointIndices;
	}

	/**
	 * @return alternating start/end column indices of the regions between the motiff regions
	 *         (the internal array, not a copy)
	 */
	public int[] getVariableEndpointIndices()
	{
		return variableEndpointIndices;
	}

	/**
	 * Builds a new alignment containing at most {@code numSamples} sequences taken at evenly
	 * spaced positions from the given alignment.
	 *
	 * @param multipleSequenceAlignment alignment to sample from
	 * @param numSamples maximum number of sequences to keep
	 * @return a new alignment with the same header and the sampled sequences, in original order
	 */
	private static MultipleSequenceAlignment sampleAlignment(MultipleSequenceAlignment multipleSequenceAlignment, int numSamples)
	{
		Vector<Sequence> sampledSequences = new Vector<Sequence>();
		List<Sequence> allSequences = multipleSequenceAlignment.getSequences();
		int seqCount = allSequences.size();

		// With fewer sequences than requested samples every sequence is taken (step 1);
		// otherwise the step is stretched so the samples span the whole alignment.
		int rows = seqCount;
		double stepSize = 1.0;
		if( seqCount >= numSamples )
		{
			rows = numSamples;
			stepSize = seqCount / ((double)numSamples);
		}

		double nextSeq = 0;
		for( int i = 0; i < rows; i++ )
		{
			// Clamp so that a position landing exactly on seqCount still maps to the last sequence.
			int index = Math.min((int)nextSeq, seqCount - 1);
			sampledSequences.add(allSequences.get(index));
			nextSeq += stepSize;
		}

		return new MultipleSequenceAlignment(multipleSequenceAlignment.getHeader(), sampledSequences);
	}

	/**
	 * Analyses the given alignment and attaches the result to it via
	 * {@code setMultipleSequenceAnalysis}.
	 *
	 * <p>The steps are:
	 * <ol>
	 * <li>sample {@code similarityFraction} of the sequences (at least one);</li>
	 * <li>score every column by summing, over all pairs of sampled sequences, the
	 *     substitution-matrix value of the two symbols in both directions, each multiplied
	 *     by the weight of one sequence of the pair;</li>
	 * <li>for every window width {@code k} from 1 to {@code windowSize}, sum the column
	 *     scores of the window centred on each column and divide by
	 *     {@code n * (n - 1) * k * k}; windows that overhang either end get 0;</li>
	 * <li>apply the corrections and locate motiff/variable regions.</li>
	 * </ol>
	 * Symbols that {@code alignmentConstants} does not define are collected along the way.
	 *
	 * @param original alignment to analyse; receives the resulting analysis
	 * @param analysisOptions window size, threshold, similarity fraction, alignment constants
	 *        and weighted string to use
	 */
	public static void analyze( MultipleSequenceAlignment original, AnalysisOptions analysisOptions)
	{
		int windowLength = analysisOptions.windowSize;
		double threshold = analysisOptions.threshold;
		String weightedString = analysisOptions.weightedString;
		AlignmentConstants constants = analysisOptions.alignmentConstants;

		int originalNumberOfSequences = original.getSequences().size();
		int numberOfSequences = Math.min((int)(analysisOptions.similarityFraction * originalNumberOfSequences), originalNumberOfSequences);
		if(numberOfSequences <= 0)
			numberOfSequences = 1;
		MultipleSequenceAlignment alignment = sampleAlignment(original, numberOfSequences);

		int minimumLength = alignment.getMinSequenceLength();
		List<Sequence> sequences = alignment.getSequences();

		// Compute column values.
		Vector<Double> columnValues = new Vector<Double>();
		Vector<Integer> gapsPerColumn = new Vector<Integer>();
		Hashtable<String, UnknownSymbol> unknownSymbols = new Hashtable<String, UnknownSymbol>();

		for( int columnNumber = 0; columnNumber < minimumLength; columnNumber++ )
		{
			// NOTE(review): nothing in this loop counts gaps, so this stays 0 and the gap
			// correction applied later is a no-op. Confirm whether gap counting was intended.
			int gapsInCurrentColumn = 0;
			double columnSum = 0;
			for( int i = 0; i < numberOfSequences; i++ )
			{
				Sequence sequence = sequences.get(i);
				String aa1 = sequence.getAminoAcid(columnNumber);
				double aa1w = sequence.getWeight();

				// Unknown symbols are reported with 1-based column positions.
				if( !constants.isSymbolDefined(aa1) )
					recordUnknownSymbol(unknownSymbols, aa1, sequence.getName(), columnNumber + 1);

				for( int j = i + 1; j < numberOfSequences; j++ )
				{
					Sequence sequence2 = sequences.get(j);
					String aa2 = sequence2.getAminoAcid(columnNumber);
					double aa2w = sequence2.getWeight();
					columnSum += AminoAcidConstants.getMatrixValue( aa1, aa2 ) * aa2w;
					columnSum += AminoAcidConstants.getMatrixValue( aa2, aa1 ) * aa1w;
				}
			}

			gapsPerColumn.add(gapsInCurrentColumn);
			columnValues.add(columnSum);
		}

		// Compute window values. Index 0 is a placeholder so that width k lives at index k.
		Vector<Vector<Double>> ucWS = new Vector<Vector<Double>>();
		ucWS.add(null);
		for( int k = 1; k <= windowLength; k++ )
		{
			Vector<Double> uncorrectedSimilarities = new Vector<Double>();

			// Multiply in double: the int product overflows for large alignments.
			double divisor = (double)numberOfSequences * k * (numberOfSequences - 1) * k;
			int halfWindow = k / 2;

			for( int leftColumn = -halfWindow; leftColumn < minimumLength - halfWindow; leftColumn++ )
			{
				double plotcon = 0;
				// Only windows lying completely inside the alignment contribute.
				if( leftColumn >= 0 && leftColumn + k <= minimumLength )
					for( int i = leftColumn; i < leftColumn + k; i++ )
						plotcon += columnValues.get( i );

				// A single sampled sequence has no pairs: report 0 instead of 0/0 = NaN.
				uncorrectedSimilarities.add(divisor == 0 ? 0.0 : plotcon / divisor);
			}
			ucWS.add(uncorrectedSimilarities);
		}

		Vector<UnknownSymbol> allUnknownSymbols = new Vector<UnknownSymbol>();
		for( UnknownSymbol unknownSymbol : unknownSymbols.values() )
			allUnknownSymbols.add(unknownSymbol);

		MultipleSequenceAnalysis multipleSequenceAnalysis =
			performCorrections( alignment, ucWS, gapsPerColumn, numberOfSequences, windowLength, minimumLength, threshold, allUnknownSymbols, weightedString);

		original.setMultipleSequenceAnalysis(multipleSequenceAnalysis);
	}

	/**
	 * Registers one occurrence of an undefined symbol, creating the entry on first sight.
	 *
	 * @param unknownSymbols entries collected so far, keyed by symbol
	 * @param symbol the undefined symbol
	 * @param sequenceName name of the sequence containing the symbol
	 * @param location 1-based column of the occurrence
	 */
	private static void recordUnknownSymbol( Hashtable<String, UnknownSymbol> unknownSymbols, String symbol,
			String sequenceName, int location )
	{
		UnknownSymbol known = unknownSymbols.get(symbol);
		if( known != null )
			known.addUnknownSymbolSpeciesAndLocation(sequenceName, location);
		else
			unknownSymbols.put(symbol, new UnknownSymbol(symbol, sequenceName, location));
	}

	/**
	 * Derives the corrected window values from the uncorrected ones and then locates the
	 * motiff and variable regions.
	 *
	 * <p>For every window width the following is applied, in this order:
	 * <ol>
	 * <li>the minimum is subtracted so the graph bottoms out at zero;</li>
	 * <li>each value is multiplied by the fraction of sequences without a gap in its column;</li>
	 * <li>the values are divided by their maximum;</li>
	 * <li>the values are multiplied by the fraction of sequences sharing the most frequent
	 *     symbol in the column where that maximum occurs.</li>
	 * </ol>
	 *
	 * @param alignment sampled alignment the values were computed from
	 * @param ucWS uncorrected values, indexed by window width (index 0 unused), then by column
	 * @param gapsPerColumn number of gaps per column
	 * @param numberOfSequences number of sampled sequences
	 * @param windowLength largest window width, also used for locating the regions
	 * @param minimumLength number of analysed columns (not used by the corrections themselves)
	 * @param threshold value separating motiff from variable columns
	 * @param unknownSymbolsInSequences undefined symbols, passed through to the result
	 * @param weightedString passed through to the result
	 * @return the finished analysis
	 */
	private static MultipleSequenceAnalysis performCorrections( MultipleSequenceAlignment alignment,
			Vector<Vector<Double>> ucWS, Vector<Integer> gapsPerColumn, int numberOfSequences, int windowLength,
			int minimumLength, double threshold, Vector<UnknownSymbol> unknownSymbolsInSequences,
			String weightedString)
	{
		Vector<Vector<Double>> correctedWindowSizes = new Vector<Vector<Double>>();
		for( int i = 0; i < ucWS.size(); i++ )
			correctedWindowSizes.add(new Vector<Double>());

		List<Sequence> sequences = alignment.getSequences();

		for( int k = 1; k <= windowLength; k++ )
		{
			Vector<Double> corrected = correctedWindowSizes.get(k);
			corrected.addAll(ucWS.get(k));

			// An alignment without columns has nothing to correct; indexing the peak column
			// below would fail.
			if( corrected.isEmpty() )
				continue;

			// Zeros graph.
			double minValue = Double.MAX_VALUE;
			for( Double value : corrected )
				if( value < minValue )
					minValue = value;

			// Shift to zero, apply the gap correction and find the first column holding the
			// maximum. NEGATIVE_INFINITY is the correct lower bound here; Double.MIN_VALUE
			// is the smallest positive double.
			double maxValue = Double.NEGATIVE_INFINITY;
			int maxColumnIndex = 0;
			for( int i = 0; i < corrected.size(); i++ )
			{
				double adjusted = (corrected.get(i) - minValue) *
						(((double)numberOfSequences - gapsPerColumn.get( i )) / numberOfSequences);
				corrected.set(i, adjusted);
				if( adjusted > maxValue )
				{
					maxValue = adjusted;
					maxColumnIndex = i;
				}
			}

			// Count how often the most frequent symbol occurs in the peak column.
			Hashtable<String, Integer> symbolCounts = new Hashtable<String, Integer>();
			int maximumOccuringAminoAcidInColumn = 0;
			for( int i = 0; i < numberOfSequences; i++ )
			{
				String currentAminoAcid = sequences.get(i).getAminoAcid(maxColumnIndex);
				Integer previousCount = symbolCounts.get(currentAminoAcid);
				int newCount = (previousCount == null) ? 1 : previousCount + 1;
				symbolCounts.put(currentAminoAcid, newCount);
				if( newCount > maximumOccuringAminoAcidInColumn )
					maximumOccuringAminoAcidInColumn = newCount;
			}

			// Normalizes graph based upon the highest column conservance. A flat graph has a
			// peak of 0; dividing by it would turn every value into NaN, so it is left at 0.
			double peak = corrected.get(maxColumnIndex);
			double normalizationConstant = (peak == 0) ? 1.0 : 1.0 / peak;
			double conservedFraction = (double)maximumOccuringAminoAcidInColumn / (double)numberOfSequences;

			for( int i = 0; i < corrected.size(); i++ )
				corrected.set(i, corrected.get(i) * normalizationConstant * conservedFraction);
		}

		return calculateMotiffEndpoints( ucWS, windowLength,
				correctedWindowSizes, threshold, unknownSymbolsInSequences, weightedString );
	}

	/**
	 * Locates motiff and variable regions in the corrected values of the configured window
	 * width and packages everything into the analysis object.
	 *
	 * <p>Scanning left to right, a motiff is opened at the first column whose value is above
	 * {@code threshold} and closed at the next column whose value is below it (that column's
	 * index is recorded as the end). A motiff still open at the end is closed at the last
	 * column. Variable regions are derived from the columns before the first motiff, between
	 * non-adjacent motiff endpoints, and after the last motiff.
	 *
	 * @param ucWS uncorrected values, passed through to the constructor
	 * @param windowLength window width whose corrected values are scanned
	 * @param cWS corrected values, indexed by window width, then by column
	 * @param threshold value separating motiff from variable columns
	 * @param unknownSymbolsInSequences undefined symbols, passed through to the result
	 * @param weightedString passed through to the result
	 * @return the finished analysis
	 */
	private static MultipleSequenceAnalysis calculateMotiffEndpoints( Vector<Vector<Double>> ucWS, int windowLength,
			Vector<Vector<Double>> cWS, double threshold, Vector<UnknownSymbol> unknownSymbolsInSequences,
			String weightedString )
	{
		Vector<Double> columnValuesInDefaultWindow = cWS.get(windowLength);
		int lastColumn = columnValuesInDefaultWindow.size() - 1;

		Vector<Integer> motiffEndpoints = new Vector<Integer>();
		Vector<Integer> variableEndpoints = new Vector<Integer>();

		boolean leftEndpointSearch = true;
		for( int i = 0; i <= lastColumn; i++ )
		{
			double value = columnValuesInDefaultWindow.get(i);
			boolean crossed = leftEndpointSearch ? (value > threshold) : (value < threshold);
			if( crossed )
			{
				motiffEndpoints.add(i);
				leftEndpointSearch = !leftEndpointSearch;
			}
		}
		// If the sequence ends in a motiff, close the motiff at the end.
		if( motiffEndpoints.size() % 2 == 1 )
			motiffEndpoints.add(lastColumn);

		// Even positions in motiffEndpoints are motiff starts, odd positions are motiff ends.
		int endpointCount = motiffEndpoints.size();
		for( int i = 0; i < endpointCount; i++ )
		{
			int current = motiffEndpoints.get(i);
			if( i == 0 )
			{
				// Columns before the first motiff form a variable region.
				if( current != 0 )
				{
					variableEndpoints.add(0);
					variableEndpoints.add(current - 1);
				}
			}
			else if( i % 2 == 0 && motiffEndpoints.get(i - 1) + 1 != current )
			{
				// Motiff start not directly after the previous end: the gap ends just before it.
				variableEndpoints.add(current - 1);
			}
			else if( i < endpointCount - 1 )
			{
				// Motiff end not directly before the next start: a gap begins just after it.
				if( i % 2 == 1 && motiffEndpoints.get(i + 1) - 1 != current )
					variableEndpoints.add(current + 1);
			}
			else if( current != lastColumn )
			{
				// Columns after the last motiff form a variable region.
				variableEndpoints.add(current + 1);
				variableEndpoints.add(lastColumn);
			}
		}

		// If the sequence ends in a variable region, close the variable region at the end.
		if( variableEndpoints.size() % 2 == 1 )
			variableEndpoints.add(lastColumn);

		return new MultipleSequenceAnalysis( ucWS, windowLength, cWS, toIntArray(motiffEndpoints),
				toIntArray(variableEndpoints), unknownSymbolsInSequences, weightedString );
	}

	/**
	 * Copies the given values into a primitive array of the same length and order.
	 */
	private static int[] toIntArray( Vector<Integer> values )
	{
		int[] result = new int[values.size()];
		for( int i = 0; i < result.length; i++ )
			result[i] = values.get(i);
		return result;
	}

	/**
	 * @return the symbols found during analysis that the alignment constants do not define
	 */
	public Vector<UnknownSymbol> getUnknownSymbols()
	{
		return unkSymbolInSeq;
	}

	/**
	 * @return the weighted string that was supplied through the analysis options
	 */
	public String getWeightedString()
	{
		return weightedString;
	}
}