package mit.ai.nl.recognition;

import java.io.*;
import mit.ai.nl.core.*;
import mit.ai.nl.database.Dataset;

import Jama.Matrix;
import Jama.SingularValueDecomposition;
import com.sun.java.util.collections.*;

/** A stroke-based character model. It keeps one Gaussian model per stroke
 *  (strokes sorted by endpoint direction, so drawing order does not matter)
 *  plus, for every stroke after the first, one Gaussian model of that
 *  stroke's geometry relative to the first stroke. */
public final class TestModel extends Model implements Serializable{

  // Convenient Constants
  private static final int X = Stroke.X;
  private static final int Y = Stroke.Y;
  static final long serialVersionUID = 3944676799961672743L;

  // Renormalization constants
  /** Weight applied to the derivative features of a stroke vector. */
  private static final float DSCALE = 3e-3f;
  /** Normalized endpoint displacement above which toFlip() treats an axis
   *  as decisive. */
  private static final float FLIPTHRESHOLD = 0.37f;
  private static final int NUMPOINTS = 36; // Number of points to resample
  // Stroke feature vector layout: NUMPOINTS x-values, NUMPOINTS y-values,
  // then the x-derivative, y-derivative and aspect-ratio features.
  private static final int DX_INDEX = NUMPOINTS*2;
  private static final int DY_INDEX = DX_INDEX+1;
  private static final int AR_INDEX = DY_INDEX+1;
  static final int FULL_DIMENSIONALITY = AR_INDEX+1;

  /** Number of strokes this model describes. */
  int numStrokes;
  /** Optional projection applied to stroke vectors in process(); may be null. */
  Matrix proj;
  /** One model per stroke, in endpoint-sorted stroke order. */
  GaussModel[] strokeModels;
  /** One model per stroke after the first: its geometry relative to stroke 0. */
  GaussModel[] geoModels;

  ///////////////////////////////////////////////////////////////////////
  /** Returns the per-stroke models when b is true, otherwise the
   *  inter-stroke geometry models. The internal array is returned, not a
   *  copy. */
  public GaussModel[] getmodels(boolean b){
    return (b) ? strokeModels : geoModels;
  }

  ///////////////////////////////////////////////////////////////////////
  /** Returns the description of the first character this model covers. */
  public String toString(){
    return Char.getChar(chars[0]).description;
  }

  ///////////////////////////////////////////////////////////////////////
  /** Returns the mean shape of every stroke as a [stroke][X|Y][point] array.
   *  The shape is read from each stroke model's mean, mapped back through the
   *  transpose of proj when a projection is set. Each stroke stays in its own
   *  normalized frame (centroid-shifted, length-scaled); the inter-stroke
   *  geometry models are not applied. */
  public float[][][] meanChar(){
    float[][][] result = new float[numStrokes][][];
    // The back-projection is the same for every stroke; compute it once.
    Matrix back = (proj == null) ? null : proj.transpose();

    for(int i=0; i<numStrokes; i++){
      Matrix m = strokeModels[i].mean;
      if(back != null)
        m = back.times(m);

      float[][] s = new float[2][NUMPOINTS];
      for(int j=0; j<NUMPOINTS; j++){
        s[X][j] = (float) m.get(j, 0);
        s[Y][j] = (float) m.get(j+NUMPOINTS, 0);
      }

      // TODO: align stroke i > 0 using geoModels[i-1]. This was never
      // implemented, so each stroke is returned in its own frame.
      result[i] = s;
    }
    return result;
  }

  ///////////////////////////////////////////////////////////////////////
  /** Sums the per-stroke GaussModel distances to mod. Returns
   *  Double.MAX_VALUE when the stroke counts differ. */
  public double dist(TestModel mod) {
    if(numStrokes != mod.numStrokes)
      return Double.MAX_VALUE;
    double val = 0;
    for(int i=0; i<numStrokes; i++)
      val += strokeModels[i].dist(mod.strokeModels[i]);
    return val;
  }

  ///////////////////////////////////////////////////////////////////////
  /** Returns true when both models have the same stroke count and every
   *  per-stroke GaussModel distance is at most thresh. */
  public boolean isSimilarTo(TestModel mod, double thresh) {
    if(numStrokes != mod.numStrokes)
      return false;
    for(int i=0; i<numStrokes; i++)
      if(strokeModels[i].dist(mod.strokeModels[i]) > thresh)
        return false;
    return true;
  }
  
  ///////////////////////////////////////////////////////////////////////
  /** Creates a TestModel from two other TestModels without modifying them.
   *  The projection, characters and stroke count are taken from mod1. Each
   *  stroke/geometry model is built with the two-argument GaussModel
   *  constructor (presumably pooling both models; see GaussModel). Do not
   *  call this on models whose stroke counts, projection matrices or
   *  characters differ. */
  public TestModel(TestModel mod1, TestModel mod2){
    proj = mod1.proj;
    chars = mod1.chars;
    numStrokes = mod1.numStrokes;
    strokeModels = new GaussModel[numStrokes];
    geoModels = new GaussModel[numStrokes-1];
    for(int i=0; i<strokeModels.length; i++)
      strokeModels[i] = new GaussModel(mod1.strokeModels[i],
                                       mod2.strokeModels[i]);

    for(int i=0; i<geoModels.length; i++)
      geoModels[i] = new GaussModel(mod1.geoModels[i],
                                    mod2.geoModels[i]);
  }

  ///////////////////////////////////////////////////////////////////////
  /** The primary constructor: builds the model from every sample in d.
   *
   *  Strokes are handled in groups by stroke position.
   *  - Flipping: if the majority of, say, first strokes in the dataset need
   *    flipping, then all first strokes are flipped.
   *  - Order invariance: stroke groups are sorted by the summed sine of the
   *    direction from the sample's bounding-box corner (getLx(), getLy()) to
   *    each stroke's last point.
   *
   *  The geometry of stroke j is measured relative to stroke 0 of the
   *  same sample after sorting. */
  public TestModel(Dataset d, char[] characters) {
    System.out.print("Making model of "+ d.initials + "'s " 
                     + Char.getChar(d.character)+"...");
    // some initial setup
    int numPoints = d.size();
    chars = characters;
    numStrokes = d.numStrokes;
    strokeModels = new GaussModel[numStrokes];
    geoModels = new GaussModel[numStrokes-1];
    
    // first, determine which stroke groups should be flipped
    boolean[] flip = groupFlips(d, numStrokes, numPoints);

    // next, load the strokes, accumulating per-group sort keys and
    // per-sample scales (only needed when there is more than one stroke)
    Stroke[][] strokes = new Stroke[numStrokes][numPoints];
    double[] angles = new double[numStrokes];
    double[] scales = new double[numPoints];
    Stroke[] sample = new Stroke[numStrokes];
    for(int i=0; i<numPoints; i++){
      float[][][] current = d.elementAt(i);
      for(int j=0; j<numStrokes; j++){
        float[][] next = current[j];
        if(flip[j]){
          // copy first so the dataset's own points are not reversed
          next = Stroke.clonePoints(current[j]);
          Stroke.flip(next);
        }
        strokes[j][i] = sample[j] = new Stroke(next);
      }

      if(numStrokes > 1){
        Box bound = new Box(sample);
        scales[i] = Math.sqrt(bound.getArea());
        float lx = bound.getLx(), ly = bound.getLy();
        for(int j=0; j<numStrokes; j++){
          int pc = sample[j].getPointCount()-1;
          angles[j] += endSine(sample[j].getX(pc)-lx, sample[j].getY(pc)-ly);
        }
      }
    }

    // now we can deal with stroke order invariance
    Util.quicksort(strokes, angles);  

    // make the models
    Matrix[] spoints = new Matrix[numPoints];
    Matrix[] gpoints = new Matrix[numPoints];

    for(int j=0; j<numStrokes; j++){
      for(int i=0; i<numPoints; i++){
        spoints[i] = processStroke(strokes[j][i]);
        if(j!=0)
          gpoints[i] = processGeo(strokes[0][i], strokes[j][i], scales[i]);
      }
      strokeModels[j]  = new GaussModel(spoints);
      if(j!=0)
        geoModels[j-1] = new GaussModel(gpoints);
    }
    
    System.out.println("Done.");
  }

  /** Returns, for each stroke position, whether that group of strokes
   *  should be flipped. A group is flipped when toFlip() votes "flip" for a
   *  strict majority of the numPoints samples in d. */
  private static boolean[] groupFlips(Dataset d, int numStrokes, int numPoints){
    boolean[] flip = new boolean[numStrokes];
    for(int i=0; i<numStrokes; i++){
      int count = 0;
      for(int j=0; j<numPoints; j++){
        float[][][] current = d.elementAt(j);
        if(toFlip(current[i], new Box(current[i])))
          count++;
      }
      flip[i] = ((1d*count)/numPoints > 0.5);
    }
    return flip;
  }

  /** Returns the sine of the direction of the vector (x, y), or 0 for the
   *  zero vector. Used as the stroke-order sort key in both the constructor
   *  and process(). */
  private static double endSine(float x, float y){
    return (x==0 && y==0) ? 0 : y/Math.sqrt(x*x+y*y);
  }

  ///////////////////////////////////////////////////////////////////////
  /** Returns the product of the per-stroke and geometry model densities
   *  for a token produced by process(). Returns 0 when the token's length
   *  is not 2*numStrokes-1. token must be a Matrix[]. */
  public double density(Object token){

    // Take care of the trivial cases
    Matrix[] vals = (Matrix []) token;
    if (vals.length != 2*numStrokes-1)
      return 0;
    
    double dens = 1;
    for(int i=0; i<numStrokes; i++)
      dens *= strokeModels[i].density(vals[i]);
    for(int i=0; i<numStrokes-1; i++)
      dens *= geoModels[i].density(vals[i+numStrokes]);
    return dens;
  }

  ///////////////////////////////////////////////////////////////////////
  /** Converts a stroke sequence into the token consumed by density(). The
   *  token is a Matrix[] of length 2n-1 for n strokes:
   *  - the first n entries are the stroke feature vectors, projected by
   *    proj when set and sorted by endpoint sine;
   *  - the remaining n-1 entries are geometry vectors relating each sorted
   *    stroke to the first sorted stroke.
   *  Feature vectors are cached on each Stroke via setToken(). An empty
   *  input yields an empty token.
   *
   *  NOTE(review): a cached token is reused no matter which model computed
   *  it, so models with different proj matrices would see a stale token.
   *  Confirm that cannot happen. */
  public Object process(Stroke[] input){
    // No strokes, no features. density() rejects a zero-length token
    // because 2*numStrokes-1 is never 0.
    if(input.length == 0)
      return new Matrix[0];

    Matrix[] vals = new Matrix[2*input.length-1];

    // Set up the strokes
    for(int i=0; i<input.length; i++) {
      Matrix s = (Matrix) input[i].getToken();
      if(s == null){
        s = processStroke(input[i]);
        if(proj != null)
          s = proj.times(s);
        input[i].setToken(s);
      }
      vals[i] = s;
    }

    if(input.length == 1)
      return vals;

    // Take care of stroke order invariance
    Box bound = new Box(input);
    float lx = bound.getLx(), ly = bound.getLy();
    double[] angles = new double[input.length];
    for(int i=0; i<input.length; i++) {
      float[][] points = input[i].getPoints();
      int end = toFlip(points, input[i]) ? 0 : points[X].length-1;
      angles[i] = endSine(points[X][end]-lx, points[Y][end]-ly);
    }

    Stroke[] input2 = (Stroke[]) input.clone();
    double[] angles2 = (double[]) angles.clone();
    Util.quicksort(input2, angles2);
    Util.quicksort(vals, angles, 0, angles.length-1);

    // Set up the geometric relations
    double scale = Math.sqrt(bound.getArea());
    for(int i=0; i<input2.length-1; i++) 
      vals[input.length+i] = processGeo(input2[0], input2[i+1], scale);
    
    return vals;
  }

  ///////////////////////////////////////////////////////////////////////
  /** Builds the 3x1 geometry vector of s2 relative to s1:
   *  - the centroid offset (x, then y) divided by scale;
   *  - Util.ratioMap of the two stroke lengths. */
  private static Matrix processGeo(Stroke s1, Stroke s2, double scale) {
    Matrix m = new Matrix(3,1);
    m.set(0, 0, (s2.getCx() - s1.getCx())/scale);
    m.set(1, 0, (s2.getCy() - s1.getCy())/scale);
    m.set(2, 0, Util.ratioMap(s1.getLength(), s2.getLength()));
    return m;
  }

  /** Builds the FULL_DIMENSIONALITY x 1 feature vector of a stroke. The
   *  stroke is first normalized:
   *  - flipped if toFlip() says so;
   *  - shifted by its centroid;
   *  - scaled by 1/length (skipped when the length is zero);
   *  - resampled to NUMPOINTS points.
   *  The vector holds the resampled x- and y-values, the scaled derivative
   *  features and the aspect ratio. */
  private static Matrix processStroke(Stroke s){
    float[][] points = s.getPoints();
    if(toFlip(points, s))
      Stroke.flip(points);
    Stroke.shift(points, s.getCentroidX(), s.getCentroidY());

    if(s.getLength() != 0)
      Stroke.scale(points, 1f/s.getLength());
    
    float delta = Stroke.resample(points, NUMPOINTS);

    // create the results array
    double[] results = new double[FULL_DIMENSIONALITY];
    for(int i=0; i<NUMPOINTS; i++) {
      results[i] = points[X][i];
      results[i+NUMPOINTS] = points[Y][i];
    }

    results[DX_INDEX] = DSCALE*getDerivative(points[X], delta);
    results[DY_INDEX] = DSCALE*getDerivative(points[Y], delta);
    results[AR_INDEX] = s.getAspectRatio();

    return new Matrix(results, FULL_DIMENSIONALITY);
  }

  ////////////////////////////////////////////////////////////////////
  /** Heuristic deciding whether a stroke should be flipped (reversed).
   *  points holds the stroke's x-values in points[X] and y-values in
   *  points[Y]; b is its bounding box.
   *
   *  The endpoint displacement (last point minus first) is normalized by
   *  the box diagonal. The result is:
   *  - false when both |dx| and |dy| are below FLIPTHRESHOLD;
   *  - "dy is negative" when both are above it;
   *  - otherwise, the sign test on whichever axis has the larger magnitude. */
  private static final boolean toFlip(float[][] points, Box b) {
    float[] pointsX = points[X], pointsY = points[Y];
    int pc = pointsX.length;
    float h = b.getHeight(), w = b.getWidth();
    float diag = (float) Math.sqrt(h*h + w*w);

    float diffX = (pointsX[pc-1] - pointsX[0])/diag;
    float diffY = (pointsY[pc-1] - pointsY[0])/diag;
    boolean flipx = diffX < 0;
    boolean flipy = diffY < 0;
    if(flipx)
      diffX = -diffX;
    if(flipy)
      diffY = -diffY;

    if((diffX < FLIPTHRESHOLD) && (diffY < FLIPTHRESHOLD))
      return false;
    if((diffX > FLIPTHRESHOLD) && (diffY > FLIPTHRESHOLD))
      return flipy;
    return (diffX > diffY) ? flipx : flipy;
  }

  ///////////////////////////////////////////////////////////////////////
  /** Returns the sum of a discrete second derivative of funct, sampled
   *  delta units apart. Both differencing passes use a forward difference
   *  at index 0 and backward differences elsewhere. Up to rounding, the sum
   *  therefore telescopes to (last first-difference minus first
   *  first-difference) / delta.
   *
   *  Despite the old "sum-squared" description, the terms are NOT squared.
   *  The squaring was disabled in the original code and is left that way
   *  because stroke models are fitted on this exact feature.
   *
   *  funct is not modified. Returns 0 when delta <= 0 or funct has fewer
   *  than two samples. */
  private static float getDerivative(float[] funct, float delta){
    // Derivatives of singularities don't make sense...
    if (delta <= 0 || funct.length <= 1)
      return 0;
    int n = funct.length;
    float[] first = new float[n];
    float[] second = new float[n];
    differentiate(funct, delta, first);
    differentiate(first, delta, second);

    float result = 0;
    for(int i=0; i<n; i++)
      result += second[i];
    return result;
  }

  /** Writes the finite-difference derivative of in into out; both arrays
   *  must have the same length, at least 2. Uses a forward difference at
   *  index 0 and backward differences at every later index. */
  private static void differentiate(float[] in, float delta, float[] out){
    out[0] = (in[1]-in[0])/delta;
    for(int j=1; j<in.length; j++)
      out[j] = (in[j]-in[j-1])/delta;
  }

}
