mev-4.0.01/source/org/tigr/microarray/mev/cluster/algorithm/impl/SVM.java

Code
Comments
Other
Rev Date Author Line
2 26 Feb 07 jari 1 /*
2 26 Feb 07 jari 2 Copyright @ 1999-2003, The Institute for Genomic Research (TIGR).
2 26 Feb 07 jari 3 All rights reserved.
2 26 Feb 07 jari 4 */
2 26 Feb 07 jari 5 /*
2 26 Feb 07 jari 6  * $RCSfile: SVM.java,v $
2 26 Feb 07 jari 7  * $Revision: 1.3 $
2 26 Feb 07 jari 8  * $Date: 2005/03/10 15:45:20 $
2 26 Feb 07 jari 9  * $Author: braistedj $
2 26 Feb 07 jari 10  * $State: Exp $
2 26 Feb 07 jari 11  */
2 26 Feb 07 jari 12
2 26 Feb 07 jari 13 package org.tigr.microarray.mev.cluster.algorithm.impl;
2 26 Feb 07 jari 14
2 26 Feb 07 jari 15 import java.util.ArrayList;
2 26 Feb 07 jari 16 import java.util.Arrays;
2 26 Feb 07 jari 17 import java.util.Random;
2 26 Feb 07 jari 18
2 26 Feb 07 jari 19 import javax.swing.JOptionPane;
2 26 Feb 07 jari 20
2 26 Feb 07 jari 21 import org.tigr.microarray.mev.cluster.Cluster;
2 26 Feb 07 jari 22 import org.tigr.microarray.mev.cluster.Node;
2 26 Feb 07 jari 23 import org.tigr.microarray.mev.cluster.NodeList;
2 26 Feb 07 jari 24 import org.tigr.microarray.mev.cluster.NodeValue;
2 26 Feb 07 jari 25 import org.tigr.microarray.mev.cluster.NodeValueList;
2 26 Feb 07 jari 26 import org.tigr.microarray.mev.cluster.algorithm.AbortException;
2 26 Feb 07 jari 27 import org.tigr.microarray.mev.cluster.algorithm.AbstractAlgorithm;
2 26 Feb 07 jari 28 import org.tigr.microarray.mev.cluster.algorithm.AlgorithmData;
2 26 Feb 07 jari 29 import org.tigr.microarray.mev.cluster.algorithm.AlgorithmEvent;
2 26 Feb 07 jari 30 import org.tigr.microarray.mev.cluster.algorithm.AlgorithmException;
2 26 Feb 07 jari 31 import org.tigr.microarray.mev.cluster.algorithm.AlgorithmParameters;
2 26 Feb 07 jari 32 import org.tigr.util.FloatMatrix;
2 26 Feb 07 jari 33
2 26 Feb 07 jari 34 /** The SVM class provides the execution code for running
2 26 Feb 07 jari 35  * the SVM algorithm and returning results.
2 26 Feb 07 jari 36  */
2 26 Feb 07 jari 37 public class SVM extends AbstractAlgorithm {
2 26 Feb 07 jari 38     
2 26 Feb 07 jari 39     private static final int POSITIVE_DIAGONAL = 0;
2 26 Feb 07 jari 40     private static final int NEGATIVE_DIAGONAL = 1;
2 26 Feb 07 jari 41     
2 26 Feb 07 jari 42     private boolean stop = false;
2 26 Feb 07 jari 43     
2 26 Feb 07 jari 44     private int function;
2 26 Feb 07 jari 45     private float factor;
2 26 Feb 07 jari 46     private boolean absolute;
2 26 Feb 07 jari 47     
2 26 Feb 07 jari 48     private int number_of_genes;
2 26 Feb 07 jari 49     private int number_of_samples;
2 26 Feb 07 jari 50     private boolean svmGenes = true;
2 26 Feb 07 jari 51     
2 26 Feb 07 jari 52     private FloatMatrix expMatrix;
2 26 Feb 07 jari 53     private boolean seenUnderflow = false;
2 26 Feb 07 jari 54     private float prevObjective;
2 26 Feb 07 jari 55     
2 26 Feb 07 jari 56     //HCL parameters
2 26 Feb 07 jari 57     private boolean calcHCL = false;
2 26 Feb 07 jari 58     private boolean calcGeneHCL = false;
2 26 Feb 07 jari 59     private boolean calcSampleHCL = false;
2 26 Feb 07 jari 60     private int method = 0;
2 26 Feb 07 jari 61     
2 26 Feb 07 jari 62     //Indicates genes vs experiment clustering
2 26 Feb 07 jari 63     private boolean classifyGens = true;
2 26 Feb 07 jari 64     
2 26 Feb 07 jari 65     /**
2 26 Feb 07 jari 66      * Executes and returns results of SVM algorithm.
2 26 Feb 07 jari 67      * @param data holds data and initial parameters
2 26 Feb 07 jari 68      * @throws AlgorithmException
2 26 Feb 07 jari 69      * @return SVM result
2 26 Feb 07 jari 70      */
2 26 Feb 07 jari 71     public AlgorithmData execute(AlgorithmData data) throws AlgorithmException {
2 26 Feb 07 jari 72         
2 26 Feb 07 jari 73         AlgorithmParameters map = data.getParams();
2 26 Feb 07 jari 74         
2 26 Feb 07 jari 75         function = map.getInt("hcl-distance-function", EUCLIDEAN);  //applies only to hcl trees on final classes, svm uses dot prod. on normalized vectors
2 26 Feb 07 jari 76         factor   = map.getFloat("distance-factor", 1.0f);
2 26 Feb 07 jari 77         absolute = map.getBoolean("hcl-distance-absolute", false);
2 26 Feb 07 jari 78         
2 26 Feb 07 jari 79         this.expMatrix = data.getMatrix("experiment");
2 26 Feb 07 jari 80         
2 26 Feb 07 jari 81         number_of_genes   = this.expMatrix.getRowDimension();
2 26 Feb 07 jari 82         number_of_samples = this.expMatrix.getColumnDimension();
2 26 Feb 07 jari 83         
2 26 Feb 07 jari 84         svmGenes = map.getBoolean("classify-genes", true);
2 26 Feb 07 jari 85         float constant = map.getFloat("constant", 0);
2 26 Feb 07 jari 86         float coefficient = map.getFloat("coefficient", 0);
2 26 Feb 07 jari 87         float power = map.getFloat("power", 0);
2 26 Feb 07 jari 88         
2 26 Feb 07 jari 89         this.calcHCL = map.getBoolean("calculate-hcl", false);
2 26 Feb 07 jari 90         this.calcGeneHCL = map.getBoolean("calculate-genes-hcl", false);
2 26 Feb 07 jari 91         this.calcSampleHCL = map.getBoolean("calculate-samples-hcl", false);
2 26 Feb 07 jari 92         
2 26 Feb 07 jari 93         this.method = map.getInt("linkage-method", 0);
2 26 Feb 07 jari 94         
2 26 Feb 07 jari 95         AlgorithmData result = new AlgorithmData();
2 26 Feb 07 jari 96         
2 26 Feb 07 jari 97         boolean isClassify = map.getBoolean("is-classify", true);
2 26 Feb 07 jari 98         if (isClassify) {
2 26 Feb 07 jari 99             FloatMatrix trainingMatrix = data.getMatrix("training");
2 26 Feb 07 jari 100             FloatMatrix weightsMatrix = data.getMatrix("weights");
2 26 Feb 07 jari 101             float[] weights = weightsMatrix.getColumnPackedCopy();
2 26 Feb 07 jari 102             FloatMatrix discriminantMatrix = classify(trainingMatrix, weights, coefficient, constant, power);
2 26 Feb 07 jari 103             
2 26 Feb 07 jari 104             int [][] clusters = new int[2][];
2 26 Feb 07 jari 105             clusters[0] = getPositives(discriminantMatrix);
2 26 Feb 07 jari 106             clusters[1] = getNegatives(discriminantMatrix);
2 26 Feb 07 jari 107             
2 26 Feb 07 jari 108             if(calcHCL){
2 26 Feb 07 jari 109                 //preparation for HCL
2 26 Feb 07 jari 110                 Cluster result_cluster = new Cluster();
2 26 Feb 07 jari 111                 NodeList nodeList = result_cluster.getNodeList();
2 26 Feb 07 jari 112                 int[] features;
2 26 Feb 07 jari 113                 for (int i=0; i<clusters.length; i++) {
2 26 Feb 07 jari 114                     if (stop) {
2 26 Feb 07 jari 115                         throw new AbortException();
2 26 Feb 07 jari 116                     }
2 26 Feb 07 jari 117                     features = clusters[i];
2 26 Feb 07 jari 118                     Node node = new Node(features);
2 26 Feb 07 jari 119                     nodeList.addNode(node);
2 26 Feb 07 jari 120                     node.setValues(calculateHierarchicalTree(features, method, calcGeneHCL, calcSampleHCL));
2 26 Feb 07 jari 121                 }
2 26 Feb 07 jari 122                 result.addCluster("cluster", result_cluster);
2 26 Feb 07 jari 123             }
2 26 Feb 07 jari 124             
2 26 Feb 07 jari 125             result.addMatrix("discriminant", discriminantMatrix);
2 26 Feb 07 jari 126             result.addIntArray("positives", getPositives(discriminantMatrix));
2 26 Feb 07 jari 127             result.addIntArray("negatives", getNegatives(discriminantMatrix));
2 26 Feb 07 jari 128             FloatMatrix means = getMeans(discriminantMatrix);
2 26 Feb 07 jari 129             result.addMatrix("means", means);
2 26 Feb 07 jari 130             result.addMatrix("variances", getVariance(discriminantMatrix, means));
2 26 Feb 07 jari 131         } else {
2 26 Feb 07 jari 132             int[] classes = data.getIntArray("classes");
2 26 Feb 07 jari 133             int seed = map.getInt("seed", 0);
2 26 Feb 07 jari 134             boolean normalize = map.getBoolean("normalize", false);
2 26 Feb 07 jari 135             boolean radial = map.getBoolean("radial", false);
2 26 Feb 07 jari 136             float widthFactor = map.getFloat("width-factor", 1.0f);
2 26 Feb 07 jari 137             float positiveDiagonal = map.getFloat("positive-diagonal", 0.0f);
2 26 Feb 07 jari 138             float negativeDiagonal = map.getFloat("negative-diagonal", 0.0f);
2 26 Feb 07 jari 139             float diagonalFactor = map.getFloat("diagonal-factor", 0.0f);
2 26 Feb 07 jari 140             float positiveConstraint = map.getFloat("positive-constraint", 1.0f);
2 26 Feb 07 jari 141             float negativeConstraint = map.getFloat("negative-constraint", 1.0f);
2 26 Feb 07 jari 142             float convergenceThreshold = map.getFloat("convergence-threshold", 0.00001f);
2 26 Feb 07 jari 143             boolean constrainWeights = map.getBoolean("constrain-weights", true);
2 26 Feb 07 jari 144             float[] weights = train(expMatrix, classes, seed, normalize, radial, coefficient, constant, power, widthFactor, positiveDiagonal, negativeDiagonal, diagonalFactor, positiveConstraint, negativeConstraint, convergenceThreshold, constrainWeights);
2 26 Feb 07 jari 145             result.addMatrix("weights", new FloatMatrix(weights, 1));
2 26 Feb 07 jari 146         }
2 26 Feb 07 jari 147         return result;
2 26 Feb 07 jari 148     }
2 26 Feb 07 jari 149     
2 26 Feb 07 jari 150     /**
2 26 Feb 07 jari 151      * Aborts current SVM in progress.
2 26 Feb 07 jari 152      */
2 26 Feb 07 jari 153     public void abort() {
2 26 Feb 07 jari 154         stop = true;
2 26 Feb 07 jari 155     }
2 26 Feb 07 jari 156     
2 26 Feb 07 jari 157     /**
2 26 Feb 07 jari 158      * Sets underflow boolean
2 26 Feb 07 jari 159      * @param seenUnderflow boolean underflow value
2 26 Feb 07 jari 160      */
2 26 Feb 07 jari 161     private final void setSeenUnderflow(boolean seenUnderflow) {
2 26 Feb 07 jari 162         this.seenUnderflow = seenUnderflow;
2 26 Feb 07 jari 163     }
2 26 Feb 07 jari 164     
2 26 Feb 07 jari 165     /**
2 26 Feb 07 jari 166      * Returns underflow state.
2 26 Feb 07 jari 167      */
2 26 Feb 07 jari 168     private final boolean isSeenUnderflow() {
2 26 Feb 07 jari 169         return seenUnderflow;
2 26 Feb 07 jari 170     }
2 26 Feb 07 jari 171     
2 26 Feb 07 jari 172     /**
2 26 Feb 07 jari 173      * Creates a base kernal matrix. (Deprecated)
2 26 Feb 07 jari 174      * @param trainingMatrix
2 26 Feb 07 jari 175      * @return
2 26 Feb 07 jari 176      */
2 26 Feb 07 jari 177     private FloatMatrix computeBaseKernelMatrix(FloatMatrix trainingMatrix) {
2 26 Feb 07 jari 178         FloatMatrix kernelMatrix = new FloatMatrix(number_of_genes, number_of_genes);
2 26 Feb 07 jari 179         
2 26 Feb 07 jari 180         for (int row1 = 0; row1 < expMatrix.getRowDimension(); row1++) {
2 26 Feb 07 jari 181             for (int row2 = 0; row2 < trainingMatrix.getRowDimension(); row2++) {
2 26 Feb 07 jari 182                 kernelMatrix.set(row1, row2, ExperimentUtil.geneDistance(expMatrix, trainingMatrix, row1, row2, function, (float)1.0, false));
2 26 Feb 07 jari 183             }
2 26 Feb 07 jari 184         }
2 26 Feb 07 jari 185         return kernelMatrix;
2 26 Feb 07 jari 186     }
2 26 Feb 07 jari 187     
2 26 Feb 07 jari 188     
2 26 Feb 07 jari 189     /**
2 26 Feb 07 jari 190      * Creates a base kernal matrix with normalized expression
2 26 Feb 07 jari 191      * vectors using the dot product metric.
2 26 Feb 07 jari 192      * @param trainingMatrix
2 26 Feb 07 jari 193      * @return
2 26 Feb 07 jari 194      */
2 26 Feb 07 jari 195     private FloatMatrix computeNormalizedBaseKernelMatrix(FloatMatrix trainingMatrix){
2 26 Feb 07 jari 196         FloatMatrix normTrainingMatrix = new FloatMatrix(this.number_of_genes, this.number_of_samples);        
2 26 Feb 07 jari 197         FloatMatrix kernelMatrix = new FloatMatrix(number_of_genes, number_of_genes);
2 26 Feb 07 jari 198         float value;
2 26 Feb 07 jari 199         float sumOfSquares = 0;
2 26 Feb 07 jari 200         for( int row = 0; row < this.number_of_genes; row++){
2 26 Feb 07 jari 201             sumOfSquares = 0;
2 26 Feb 07 jari 202             for(int col = 0; col < this.number_of_samples; col++){
2 26 Feb 07 jari 203                 value = trainingMatrix.get(row,col);
2 26 Feb 07 jari 204                 if( !Float.isNaN(value) )
2 26 Feb 07 jari 205                     sumOfSquares += Math.pow(value, 2);
2 26 Feb 07 jari 206             }
2 26 Feb 07 jari 207             if(sumOfSquares != 0.0){
2 26 Feb 07 jari 208                 sumOfSquares = (float) Math.sqrt(sumOfSquares);
2 26 Feb 07 jari 209                 for(int col = 0; col < this.number_of_samples; col++){
2 26 Feb 07 jari 210                     normTrainingMatrix.set( row, col, trainingMatrix.get(row, col)/sumOfSquares);
2 26 Feb 07 jari 211                 }
2 26 Feb 07 jari 212             }
2 26 Feb 07 jari 213         }
2 26 Feb 07 jari 214         int N1 = expMatrix.getRowDimension();
2 26 Feb 07 jari 215         int N2 = trainingMatrix.getRowDimension();
2 26 Feb 07 jari 216         float kernalValue;
2 26 Feb 07 jari 217         for (int row1 = 0; row1 < N1; row1++) {
2 26 Feb 07 jari 218             for (int row2 = row1; row2 < N2; row2++) {
2 26 Feb 07 jari 219                 kernalValue = geneDotProduct(normTrainingMatrix, normTrainingMatrix, row1, row2);
2 26 Feb 07 jari 220                 kernelMatrix.set(row1, row2, kernalValue);
2 26 Feb 07 jari 221                 kernelMatrix.set(row2, row1, kernalValue);
2 26 Feb 07 jari 222             }
2 26 Feb 07 jari 223         }        
2 26 Feb 07 jari 224         return kernelMatrix;
2 26 Feb 07 jari 225     }
2 26 Feb 07 jari 226     
2 26 Feb 07 jari 227     
2 26 Feb 07 jari 228     /**
2 26 Feb 07 jari 229      * Creates
2 26 Feb 07 jari 230      * @param kernelMatrix
2 26 Feb 07 jari 231      * @return
2 26 Feb 07 jari 232      */
2 26 Feb 07 jari 233     private float[] createSelfKernelValues(FloatMatrix kernelMatrix) {
2 26 Feb 07 jari 234         float[] selfKernelValues = new float[kernelMatrix.getRowDimension()];
2 26 Feb 07 jari 235         extractSelfKernelValues(kernelMatrix, selfKernelValues);
2 26 Feb 07 jari 236         return selfKernelValues;
2 26 Feb 07 jari 237     }
2 26 Feb 07 jari 238     
2 26 Feb 07 jari 239     /**
2 26 Feb 07 jari 240      * Extract the diagonal from a given (square) kernel matrix.
2 26 Feb 07 jari 241      */
2 26 Feb 07 jari 242     private void extractSelfKernelValues(FloatMatrix kernelMatrix, float[] selfKernelValues) {
2 26 Feb 07 jari 243         final int size = kernelMatrix.getRowDimension();
2 26 Feb 07 jari 244         for (int i=0; i<size; i++) {
2 26 Feb 07 jari 245             selfKernelValues[i] = kernelMatrix.get(i, i);
2 26 Feb 07 jari 246         }
2 26 Feb 07 jari 247     }
2 26 Feb 07 jari 248     
2 26 Feb 07 jari 249     /**
2 26 Feb 07 jari 250      * Given three parameters, A, B and C, compute (B(X + C))^A.
2 26 Feb 07 jari 251      */
2 26 Feb 07 jari 252     private final float polynomialize(float power, float coefficient, float constant, float value) {
2 26 Feb 07 jari 253         value += constant;
2 26 Feb 07 jari 254         value *= coefficient;
2 26 Feb 07 jari 255         return(float)Math.pow(value, power);
2 26 Feb 07 jari 256     }
2 26 Feb 07 jari 257     
2 26 Feb 07 jari 258     /**
2 26 Feb 07 jari 259      * Given three parameters, A, B and C, replace each element X in a
2 26 Feb 07 jari 260      * given matrix by the value (B(X + C))^A.  Also perform the same
2 26 Feb 07 jari 261      * operation on two given arrays.
2 26 Feb 07 jari 262      */
2 26 Feb 07 jari 263     private void polynomializeMatrix(FloatMatrix kernelMatrix, float[] selfKernelValues, float power, float coefficient, float constant) {
2 26 Feb 07 jari 264         final int rows = kernelMatrix.getRowDimension();
2 26 Feb 07 jari 265         final int columns = kernelMatrix.getColumnDimension();
2 26 Feb 07 jari 266         for (int row=0; row < rows; row++) {
2 26 Feb 07 jari 267             for (int column=0; column < columns; column++) {
2 26 Feb 07 jari 268                 kernelMatrix.set(row, column, polynomialize(power, coefficient, constant, kernelMatrix.get(row, column)));                
2 26 Feb 07 jari 269             }
2 26 Feb 07 jari 270         }
2 26 Feb 07 jari 271         /* Polynomialize the self-kernel values. */
2 26 Feb 07 jari 272         for (int row=0; row < rows; row++) {
2 26 Feb 07 jari 273             selfKernelValues[row] = polynomialize(power, coefficient, constant, selfKernelValues[row]);
2 26 Feb 07 jari 274         }
2 26 Feb 07 jari 275     }
2 26 Feb 07 jari 276     
2 26 Feb 07 jari 277     /**
2 26 Feb 07 jari 278      * Classify a single example.
2 26 Feb 07 jari 279      */
2 26 Feb 07 jari 280     private final float classify(FloatMatrix kernelMatrix, float[] weights, int test) {
2 26 Feb 07 jari 281         float returnValue;
2 26 Feb 07 jari 282         float thisWeight;
2 26 Feb 07 jari 283         float thisValue;
2 26 Feb 07 jari 284         
2 26 Feb 07 jari 285         returnValue = 0.0f;
2 26 Feb 07 jari 286         for (int i = 0; i<weights.length; i++) {
2 26 Feb 07 jari 287             /* Get the current weight. */
2 26 Feb 07 jari 288             thisWeight = weights[i];
2 26 Feb 07 jari 289             /* If the weight is zero, skip. */
2 26 Feb 07 jari 290             if (thisWeight == 0.0) {
2 26 Feb 07 jari 291                 continue;
2 26 Feb 07 jari 292             }
2 26 Feb 07 jari 293             /* Compute the distance between the two examples. */
2 26 Feb 07 jari 294             thisValue = kernelMatrix.get(test, i);
2 26 Feb 07 jari 295             /* Weight the distance appropriately. This assumes that the
2 26 Feb 07 jari 296                classification of the training set example is encoded in the
2 26 Feb 07 jari 297                sign of the weight. */
2 26 Feb 07 jari 298             thisValue *= thisWeight;
2 26 Feb 07 jari 299             returnValue += thisValue;
2 26 Feb 07 jari 300         }
2 26 Feb 07 jari 301         return returnValue;
2 26 Feb 07 jari 302     }
2 26 Feb 07 jari 303     
2 26 Feb 07 jari 304     /**
2 26 Feb 07 jari 305      * Classify a list of examples.
2 26 Feb 07 jari 306      */
2 26 Feb 07 jari 307     private FloatMatrix classifyList(FloatMatrix kernelMatrix, float[] weights) {
2 26 Feb 07 jari 308         float thisDiscriminant;
2 26 Feb 07 jari 309         final int rows = kernelMatrix.getRowDimension();
2 26 Feb 07 jari 310         FloatMatrix discriminantMatrix = new FloatMatrix(rows, 2);
2 26 Feb 07 jari 311         for (int i=0; i<rows; i++) {
2 26 Feb 07 jari 312             /* Compute the discriminant. */
2 26 Feb 07 jari 313             thisDiscriminant = classify(kernelMatrix, weights, i);
2 26 Feb 07 jari 314             /* Store the classification. */
2 26 Feb 07 jari 315             if (thisDiscriminant >= 0.0) {
2 26 Feb 07 jari 316                 discriminantMatrix.set(i, 0, 1.0f);
2 26 Feb 07 jari 317             } else {
2 26 Feb 07 jari 318                 discriminantMatrix.set(i, 0, -1.0f);
2 26 Feb 07 jari 319             }
2 26 Feb 07 jari 320             /* Store the discriminant. */
2 26 Feb 07 jari 321             discriminantMatrix.set(i, 1, thisDiscriminant);
2 26 Feb 07 jari 322         }
2 26 Feb 07 jari 323         return discriminantMatrix;
2 26 Feb 07 jari 324     }
2 26 Feb 07 jari 325     
2 26 Feb 07 jari 326     
2 26 Feb 07 jari 327     /**
2 26 Feb 07 jari 328      * Normalize a kernel matrix.
2 26 Feb 07 jari 329      *
2 26 Feb 07 jari 330      * Let x be a vector of unnormalized features. Let \tilde{x} be the
2 26 Feb 07 jari 331      * vector of normalized features.
2 26 Feb 07 jari 332      *
2 26 Feb 07 jari 333      *         \tilde{x}_i = x_i/||x||
2 26 Feb 07 jari 334      *
2 26 Feb 07 jari 335      * where ||x|| is the norm of x. This means that ||\tilde{x}|| = 1 for
2 26 Feb 07 jari 336      * all x. What is happening when you normalize in this way is that all
2 26 Feb 07 jari 337      * feature vectors x are getting their lengths scaled so that they lie
2 26 Feb 07 jari 338      * on the surface of the unit sphere in 79 dimensional space.
2 26 Feb 07 jari 339      *
2 26 Feb 07 jari 340      * It turns out that this kind of normalization can be done in general
2 26 Feb 07 jari 341      * for any kernel. If K(x,y) is any kernel supplied by a user, then
2 26 Feb 07 jari 342      * you can normalize it by defining
2 26 Feb 07 jari 343      *
2 26 Feb 07 jari 344      *     \tilde{K}(x,y) = K(x,y)/\sqrt{K(x,x)}\sqrt{K(y,y)}
2 26 Feb 07 jari 345      *
2 26 Feb 07 jari 346      * It turns out that \sqrt{K(x,x)} is the norm of x in the feature
2 26 Feb 07 jari 347      * space.  Hence, for any x, the norm of x in the feature space
2 26 Feb 07 jari 348      * defined by \tilde{K} is
2 26 Feb 07 jari 349      *
2 26 Feb 07 jari 350      *           \sqrt{\tilde{K}(x,x)}  = 1
2 26 Feb 07 jari 351      *
2 26 Feb 07 jari 352      * When you normalize the kernel this way, it ensures that all points
2 26 Feb 07 jari 353      * are mapped to the surface of the unit ball in some (possibly infinite
2 26 Feb 07 jari 354      * dimensional) feature space.
2 26 Feb 07 jari 355      */
2 26 Feb 07 jari 356     private void normalizeKernelMatrix(FloatMatrix kernelMatrix, float[] selfKernelValues) {
2 26 Feb 07 jari 357         float rowDiag;
2 26 Feb 07 jari 358         float columnDiag;
2 26 Feb 07 jari 359         float cell;
2 26 Feb 07 jari 360         final int rows = kernelMatrix.getRowDimension();
2 26 Feb 07 jari 361         final int columns = kernelMatrix.getColumnDimension();
2 26 Feb 07 jari 362         for (int row=0; row<rows; row++) {
2 26 Feb 07 jari 363             rowDiag = (float)Math.sqrt(selfKernelValues[row]);
2 26 Feb 07 jari 364             for (int column=0; column<columns; column++) {
2 26 Feb 07 jari 365                 columnDiag = (float)Math.sqrt(selfKernelValues[column]);
2 26 Feb 07 jari 366                 cell = kernelMatrix.get(row, column);
2 26 Feb 07 jari 367                 cell /= rowDiag*columnDiag*1.0;
2 26 Feb 07 jari 368                 kernelMatrix.set(row, column, cell);
2 26 Feb 07 jari 369             }
2 26 Feb 07 jari 370         }
2 26 Feb 07 jari 371         for (int row=0; row<rows; row++) {
2 26 Feb 07 jari 372             selfKernelValues[row] = 1.0f;
2 26 Feb 07 jari 373         }
2 26 Feb 07 jari 374     }
2 26 Feb 07 jari 375     
2 26 Feb 07 jari 376     
2 26 Feb 07 jari 377     /**
2 26 Feb 07 jari 378      * Compute the median value in an array.
2 26 Feb 07 jari 379      *
2 26 Feb 07 jari 380      * Sorts the array as a side effect.
2 26 Feb 07 jari 381      */
2 26 Feb 07 jari 382     private float computeMedian(float[] array) {
2 26 Feb 07 jari 383         int   numberOfItems;
2 26 Feb 07 jari 384         float returnValue;
2 26 Feb 07 jari 385         
2 26 Feb 07 jari 386         // Sort the array.
2 26 Feb 07 jari 387         Arrays.sort(array);
2 26 Feb 07 jari 388         numberOfItems = array.length;
2 26 Feb 07 jari 389         if (numberOfItems % 2 == 1) {
2 26 Feb 07 jari 390             // If there are an odd number of elements, return the middle one.
2 26 Feb 07 jari 391             returnValue = array[numberOfItems/2];
2 26 Feb 07 jari 392         } else {
2 26 Feb 07 jari 393             // Otherwise, return the average of the two middle ones.
2 26 Feb 07 jari 394             returnValue  = array[numberOfItems/2 - 1];
2 26 Feb 07 jari 395             returnValue += array[numberOfItems/2];
2 26 Feb 07 jari 396             returnValue /= 2.0;
2 26 Feb 07 jari 397         }
2 26 Feb 07 jari 398         return returnValue;
2 26 Feb 07 jari 399     }
2 26 Feb 07 jari 400     
2 26 Feb 07 jari 401     /**
2 26 Feb 07 jari 402      * Define the squared distance as follows:
2 26 Feb 07 jari 403      *
2 26 Feb 07 jari 404      *     d^2(x,y) = K(x,x) - 2 K(x,y) + K(y,y)
2 26 Feb 07 jari 405      *
2 26 Feb 07 jari 406      * This is the squared Euclidean distance between points in feature
2 26 Feb 07 jari 407      * space.
2 26 Feb 07 jari 408      */
2 26 Feb 07 jari 409     private final float computeSquaredDistance(float Kxx, float Kxy, float Kyy) {
2 26 Feb 07 jari 410         return Kxx - 2*Kxy + Kyy;
2 26 Feb 07 jari 411     }
2 26 Feb 07 jari 412     
2 26 Feb 07 jari 413     /**
2 26 Feb 07 jari 414      * Set the width of a radial basis kernel to be the median of the
2 26 Feb 07 jari 415      * distances from each positive example to the nearest negative
2 26 Feb 07 jari 416      * example.
2 26 Feb 07 jari 417      *
2 26 Feb 07 jari 418      * This is only called during training, so we know the kernel matrix
2 26 Feb 07 jari 419      * is square.
2 26 Feb 07 jari 420      */
2 26 Feb 07 jari 421     private float computeTwoSquaredWidth(FloatMatrix kernelMatrix, int[] classes, float widthFactor) {
2 26 Feb 07 jari 422         int     numberOfPositives; // Total number of positive examples.
2 26 Feb 07 jari 423         float[] nearestNegatives;   // Distances to nearest negative example.
2 26 Feb 07 jari 424         int     positive;
2 26 Feb 07 jari 425         float   squaredDistance;    // Squared distance between two examples.
2 26 Feb 07 jari 426         float   returnValue;
2 26 Feb 07 jari 427         final int rows = kernelMatrix.getRowDimension();
2 26 Feb 07 jari 428         // Count the number of positive examples.
2 26 Feb 07 jari 429         numberOfPositives = 0;
2 26 Feb 07 jari 430         for (int i=0; i<rows; i++) {
2 26 Feb 07 jari 431             if (classes[i] == 1) {
2 26 Feb 07 jari 432                 numberOfPositives++;
2 26 Feb 07 jari 433             }
2 26 Feb 07 jari 434         }
2 26 Feb 07 jari 435         // Allocate the array of distances to nearest negatives.
2 26 Feb 07 jari 436         nearestNegatives = new float[numberOfPositives];
2 26 Feb 07 jari 437         // Find the nearest negative example for each positive.
2 26 Feb 07 jari 438         positive = -1;
2 26 Feb 07 jari 439         for (int i=0; i<rows; i++) {
2 26 Feb 07 jari 440             // Consider only positive examples.
2 26 Feb 07 jari 441             if (classes[i] == 1) {
2 26 Feb 07 jari 442                 positive++;
2 26 Feb 07 jari 443                 // Initialize this position to something very large.
2 26 Feb 07 jari 444                 nearestNegatives[positive] = Float.MAX_VALUE;
2 26 Feb 07 jari 445                 for (int j=0; j<rows; j++) {
2 26 Feb 07 jari 446                     // Consider only negative examples.
2 26 Feb 07 jari 447                     if (classes[j] != 1) {
2 26 Feb 07 jari 448                         // Compute the distance between these examples.
2 26 Feb 07 jari 449                         squaredDistance = computeSquaredDistance(kernelMatrix.get(i, i), kernelMatrix.get(i,j), kernelMatrix.get(j,j));
2 26 Feb 07 jari 450                         // Store the minimum.
2 26 Feb 07 jari 451                         if (nearestNegatives[positive] > squaredDistance) {
2 26 Feb 07 jari 452                             nearestNegatives[positive] = squaredDistance;
2 26 Feb 07 jari 453                         }
2 26 Feb 07 jari 454                     }
2 26 Feb 07 jari 455                 }
2 26 Feb 07 jari 456             }
2 26 Feb 07 jari 457         }
2 26 Feb 07 jari 458         // Find the median distance.
2 26 Feb 07 jari 459         returnValue = computeMedian(nearestNegatives);
2 26 Feb 07 jari 460         // Multiply in the given width factor and a factor of 2.
2 26 Feb 07 jari 461         returnValue *= 2.0 * widthFactor;
2 26 Feb 07 jari 462         // Return the result.
2 26 Feb 07 jari 463         return returnValue;
2 26 Feb 07 jari 464     }
2 26 Feb 07 jari 465     
2 26 Feb 07 jari 466     
2 26 Feb 07 jari 467     /**
2 26 Feb 07 jari 468      * Compute a radial basis function kernel, defined by
2 26 Feb 07 jari 469      *
2 26 Feb 07 jari 470      *     K(x,y) = exp{d^2(x,y)/2\sigma^2}
2 26 Feb 07 jari 471      *
2 26 Feb 07 jari 472      * for some constant \sigma (the width).
2 26 Feb 07 jari 473      */
2 26 Feb 07 jari 474     private float radialKernel(float twoSquaredWidth, float Kxx, float Kxy, float Kyy) {
2 26 Feb 07 jari 475         float returnValue;
2 26 Feb 07 jari 476         // Compute the squared distances between the examples.
2 26 Feb 07 jari 477         returnValue = computeSquaredDistance(Kxx, Kxy, Kyy);
2 26 Feb 07 jari 478         // Divide by twice sigma squared.
2 26 Feb 07 jari 479         returnValue /= twoSquaredWidth;
2 26 Feb 07 jari 480         // Exponentiate the opposite.
2 26 Feb 07 jari 481         returnValue = (float)Math.exp(-returnValue);
2 26 Feb 07 jari 482         // Make sure we didn't hit zero.
2 26 Feb 07 jari 483         if (returnValue == 0.0 && !isSeenUnderflow()) {
2 26 Feb 07 jari 484             setSeenUnderflow(true);
2 26 Feb 07 jari 485         }
2 26 Feb 07 jari 486         return returnValue;
2 26 Feb 07 jari 487     }
2 26 Feb 07 jari 488     
2 26 Feb 07 jari 489     /**
2 26 Feb 07 jari 490      * Convert each value in a given kernel matrix to a radial basis
2 26 Feb 07 jari 491      * version.
2 26 Feb 07 jari 492      */
2 26 Feb 07 jari 493     private void radializeMatrix(FloatMatrix kernelMatrix, float[] selfKernelValues, float twoSquaredWidth, float constant) {
2 26 Feb 07 jari 494         float radialValue;
2 26 Feb 07 jari 495         // Radialize each row of the matrix.
2 26 Feb 07 jari 496         final int rows = kernelMatrix.getRowDimension();
2 26 Feb 07 jari 497         final int columns = kernelMatrix.getColumnDimension();
2 26 Feb 07 jari 498         for (int row = 0; row < rows; row++) {
2 26 Feb 07 jari 499             for (int column = 0; column < columns; column++) {
2 26 Feb 07 jari 500                 // Compute the new value.
2 26 Feb 07 jari 501                 radialValue = radialKernel(twoSquaredWidth, selfKernelValues[row], kernelMatrix.get(row, column), selfKernelValues[column]);
2 26 Feb 07 jari 502                 // Add the constant back in.
2 26 Feb 07 jari 503                 radialValue += constant;
2 26 Feb 07 jari 504                 // Store the new value.
2 26 Feb 07 jari 505                 kernelMatrix.set(row, column, radialValue);
2 26 Feb 07 jari 506             }
2 26 Feb 07 jari 507         }
2 26 Feb 07 jari 508         /* Extract the radialized self-kernel values. This is necessary
2 26 Feb 07 jari 509            because these values are used during the computation of the
2 26 Feb 07 jari 510            diagonal factor.  This computation only occurs during training
2 26 Feb 07 jari 511            (not classification).  During training, the self-kernel values
2 26 Feb 07 jari 512            are the same for the rows and the columns.  Hence, we only need
2 26 Feb 07 jari 513            to extract values for the rows.  */
2 26 Feb 07 jari 514         extractSelfKernelValues(kernelMatrix, selfKernelValues);
2 26 Feb 07 jari 515     }
2 26 Feb 07 jari 516     
2 26 Feb 07 jari 517     /*
2 26 Feb 07 jari 518      * One good way to set the constants to add to the diagonal is
2 26 Feb 07 jari 519      *
2 26 Feb 07 jari 520      *   n+ = number of positive examples
2 26 Feb 07 jari 521      *   n- = number of negative examples
2 26 Feb 07 jari 522      *   N  = total number of examples
2 26 Feb 07 jari 523      *   k  = some constant (given by diagonal_factor)
2 26 Feb 07 jari 524      *
2 26 Feb 07 jari 525      * Then set
2 26 Feb 07 jari 526      *
2 26 Feb 07 jari 527      *   positive_diagonal = (n+/N) * k
2 26 Feb 07 jari 528      *   negative_diagonal = (n-/N) * k
2 26 Feb 07 jari 529      *
2 26 Feb 07 jari 530      */
2 26 Feb 07 jari 531     /**
2 26 Feb 07 jari 532      * Returns the diagonal constants
2 26 Feb 07 jari 533      */
2 26 Feb 07 jari 534     private float[] getDiagonalConstants(float[] selfKernelValues, int[] classes, float diagonalFactor) {
2 26 Feb 07 jari 535         // If the diagonal factor is zero, do nothing.
2 26 Feb 07 jari 536         if (diagonalFactor == 0) {
2 26 Feb 07 jari 537             return null;
2 26 Feb 07 jari 538         }
2 26 Feb 07 jari 539         int numberOfExamples = classes.length;
2 26 Feb 07 jari 540         int numberOfPositive = 0;
2 26 Feb 07 jari 541         int numberOfNegative = 0;
2 26 Feb 07 jari 542         // Find the median self-kernel value.
2 26 Feb 07 jari 543         float medianDiagonal = computeMedian(selfKernelValues);
2 26 Feb 07 jari 544         // Count the number of positives and negatives.
2 26 Feb 07 jari 545         for (int i=0; i<numberOfExamples; i++) {
2 26 Feb 07 jari 546             if (classes[i]==1) {
2 26 Feb 07 jari 547                 numberOfPositive++;
2 26 Feb 07 jari 548             } else {
2 26 Feb 07 jari 549                 numberOfNegative++;
2 26 Feb 07 jari 550             }
2 26 Feb 07 jari 551         }
2 26 Feb 07 jari 552         float[] diagonals = new float[2];
2 26 Feb 07 jari 553         diagonals[POSITIVE_DIAGONAL] = ((float)numberOfPositive/(float)numberOfExamples)*diagonalFactor*medianDiagonal;
2 26 Feb 07 jari 554         diagonals[NEGATIVE_DIAGONAL] = ((float)numberOfNegative/(float)numberOfExamples)*diagonalFactor*medianDiagonal;
2 26 Feb 07 jari 555         //diagonals[POSITIVE_DIAGONAL] = ((float)numberOfPositive/(float)numberOfExamples)*diagonalFactor;
2 26 Feb 07 jari 556         //diagonals[NEGATIVE_DIAGONAL] = ((float)numberOfNegative/(float)numberOfExamples)*diagonalFactor;
2 26 Feb 07 jari 557         return diagonals;
2 26 Feb 07 jari 558     }
2 26 Feb 07 jari 559     
2 26 Feb 07 jari 560     /*
2 26 Feb 07 jari 561      * Add a constant to the diagonal of the kernel matrix.
2 26 Feb 07 jari 562      *
2 26 Feb 07 jari 563      * This can be used to accomplish two things:
2 26 Feb 07 jari 564      *
2 26 Feb 07 jari 565      * (1) If the kernel is not positive definite, then adding a
2 26 Feb 07 jari 566      *     sufficiently large constant to the diagonal will make it so
2 26 Feb 07 jari 567      *     (although I don't know how to calculate a priori the proper
2 26 Feb 07 jari 568      *     value to add).
2 26 Feb 07 jari 569      *
2 26 Feb 07 jari 570      * (2) Adding to the diagonal also effectively scales the weights.  A
2 26 Feb 07 jari 571      *     larger constant makes the weights smaller.  Adding different
2 26 Feb 07 jari 572      *     constants for the positives and negatives has the same effect
2 26 Feb 07 jari 573      *     as placing different constraint ceilings.
2 26 Feb 07 jari 574      *
2 26 Feb 07 jari 575      */
2 26 Feb 07 jari 576     private void addToKernelDiagonal(FloatMatrix kernelMatrix, float[] selfKernelValues, int[] classes, float positiveDiagonal, float negativeDiagonal, float diagonalFactor) {
2 26 Feb 07 jari 577         float[] diagonals = getDiagonalConstants(selfKernelValues, classes, diagonalFactor);
2 26 Feb 07 jari 578         if (diagonals != null) {
2 26 Feb 07 jari 579             positiveDiagonal = diagonals[POSITIVE_DIAGONAL];
2 26 Feb 07 jari 580             negativeDiagonal = diagonals[NEGATIVE_DIAGONAL];
2 26 Feb 07 jari 581         }
2 26 Feb 07 jari 582         final int rows = kernelMatrix.getRowDimension();
2 26 Feb 07 jari 583         for (int row=0; row < rows; row++) {
2 26 Feb 07 jari 584             if (classes[row] == 1) {
2 26 Feb 07 jari 585                 kernelMatrix.set(row, row, kernelMatrix.get(row, row) + positiveDiagonal);
2 26 Feb 07 jari 586             } else {
2 26 Feb 07 jari 587                 kernelMatrix.set(row, row, kernelMatrix.get(row, row) + negativeDiagonal);
2 26 Feb 07 jari 588             }
2 26 Feb 07 jari 589             
2 26 Feb 07 jari 590             
2 26 Feb 07 jari 591         }
2 26 Feb 07 jari 592     }
2 26 Feb 07 jari 593     
2 26 Feb 07 jari 594     /*
2 26 Feb 07 jari 595      * The discriminant function is used to determine whether a given
2 26 Feb 07 jari 596      * example is classified positively or negatively.
2 26 Feb 07 jari 597      *
2 26 Feb 07 jari 598      * This function implements equation (4) from the paper cited above.
2 26 Feb 07 jari 599      */
2 26 Feb 07 jari 600     private float computeDiscriminant(FloatMatrix kernelMatrix, float[] weights, int[] classes, int thisItem) {
2 26 Feb 07 jari 601         float returnValue = 0.0f;
2 26 Feb 07 jari 602         for (int i=0; i<classes.length; i++) {
2 26 Feb 07 jari 603             /* Weight the distance appropriately and
2 26 Feb 07 jari 604             add or subtract, depending upon whether this is a positive or
2 26 Feb 07 jari 605             negative example. */
2 26 Feb 07 jari 606             if(!(Float.isNaN(kernelMatrix.get(thisItem, i) )))
2 26 Feb 07 jari 607                 returnValue += weights[i]*kernelMatrix.get(thisItem, i)*classes[i];
2 26 Feb 07 jari 608         }
2 26 Feb 07 jari 609         return returnValue;
2 26 Feb 07 jari 610     }
2 26 Feb 07 jari 611     
2 26 Feb 07 jari 612     /*
2 26 Feb 07 jari 613      * Keep a local copy of the weights array and signal if they've
2 26 Feb 07 jari 614      * stopped changing.
2 26 Feb 07 jari 615      *
2 26 Feb 07 jari 616      * Convergence is reached when the delta is below the convergence
2 26 Feb 07 jari 617      * threshold.
2 26 Feb 07 jari 618      */
2 26 Feb 07 jari 619     private boolean converged(AlgorithmEvent event, FloatMatrix kernelMatrix, float[] weights, int[] classes, float convergenceThreshold) {
2 26 Feb 07 jari 620         float objective;       // Current value of the objective.
2 26 Feb 07 jari 621         float delta = 0.0f;    // Change in objective.
2 26 Feb 07 jari 622         
2 26 Feb 07 jari 623         // Compute the new objective.
2 26 Feb 07 jari 624         objective = computeObjective(kernelMatrix, weights, classes);
2 26 Feb 07 jari 625         
2 26 Feb 07 jari 626         // Compute the change in objective.
2 26 Feb 07 jari 627         delta = objective - prevObjective;
2 26 Feb 07 jari 628         
2 26 Feb 07 jari 629         // Store this objective for next time.
2 26 Feb 07 jari 630         prevObjective = objective;
2 26 Feb 07 jari 631         
2 26 Feb 07 jari 632         if(!Float.isNaN(delta) && !Float.isInfinite(delta)){
2 26 Feb 07 jari 633             event.setFloatValue(Math.abs(delta));
2 26 Feb 07 jari 634             fireValueChanged(event);
2 26 Feb 07 jari 635         }
2 26 Feb 07 jari 636         
2 26 Feb 07 jari 637         return(Math.abs(delta) < convergenceThreshold);
2 26 Feb 07 jari 638     }
2 26 Feb 07 jari 639     
2 26 Feb 07 jari 640     /*
2 26 Feb 07 jari 641      * Compute the objective function, equation (7).
2 26 Feb 07 jari 642      */
2 26 Feb 07 jari 643     private float computeObjective(FloatMatrix kernelMatrix, float[] weights, int[] classes) {
2 26 Feb 07 jari 644         float sum = 0.0f;
2 26 Feb 07 jari 645         for (int i=0; i<classes.length; i++) {
2 26 Feb 07 jari 646             sum += weights[i]*(2.0-(computeDiscriminant(kernelMatrix, weights, classes, i)*classes[i]));
2 26 Feb 07 jari 647         }
2 26 Feb 07 jari 648         return sum;
2 26 Feb 07 jari 649     }
2 26 Feb 07 jari 650     
2 26 Feb 07 jari 651     /*
2 26 Feb 07 jari 652      * Update one item's weight.  This update rule maximizes the
2 26 Feb 07 jari 653      * constrained maximization of J(\lambda).  This function implements
2 26 Feb 07 jari 654      * equations (9) and (10) in Jaakkola et al.
2 26 Feb 07 jari 655      */
2 26 Feb 07 jari 656     private float updateWeight(FloatMatrix kernelMatrix, float[] weights, int[] classes, float constraint, boolean constrainWeights, int thisItem) {
2 26 Feb 07 jari 657         float thisDiscriminant;
2 26 Feb 07 jari 658         float selfDistance;
2 26 Feb 07 jari 659         float thisWeight;
2 26 Feb 07 jari 660         float newWeight;
2 26 Feb 07 jari 661         float thisClass;
2 26 Feb 07 jari 662         
2 26 Feb 07 jari 663         thisDiscriminant = computeDiscriminant(kernelMatrix, weights, classes, thisItem);
2 26 Feb 07 jari 664         selfDistance = kernelMatrix.get(thisItem, thisItem);
2 26 Feb 07 jari 665         thisWeight = weights[thisItem];
2 26 Feb 07 jari 666         // Weight negative examples oppositely.
2 26 Feb 07 jari 667         thisClass = classes[thisItem];
2 26 Feb 07 jari 668         // This is equation (8).
2 26 Feb 07 jari 669         newWeight = 1.0f-(thisClass*thisDiscriminant)+(thisWeight*selfDistance);
2 26 Feb 07 jari 670         // Divide by k(x,x), checking for divide-by-zero.
2 26 Feb 07 jari 671         if (selfDistance == 0.0) {
2 26 Feb 07 jari 672             newWeight /= newWeight;  /* ?????????????????????????????*/
2 26 Feb 07 jari 673         } else {
2 26 Feb 07 jari 674             newWeight /= selfDistance;
2 26 Feb 07 jari 675         }
2 26 Feb 07 jari 676         if (selfDistance != 0.0) {
2 26 Feb 07 jari 677             thisWeight = weights[thisItem];
2 26 Feb 07 jari 678             weights[thisItem] = newWeight;
2 26 Feb 07 jari 679             weights[thisItem] = newWeight;
2 26 Feb 07 jari 680             thisDiscriminant  = computeDiscriminant(kernelMatrix, weights, classes, thisItem);
2 26 Feb 07 jari 681             weights[thisItem] = thisWeight;
2 26 Feb 07 jari 682         }
2 26 Feb 07 jari 683         // Constrain the weight.
2 26 Feb 07 jari 684         if (constrainWeights && (newWeight > constraint)) {
2 26 Feb 07 jari 685             newWeight = constraint;
2 26 Feb 07 jari 686         } else if (newWeight < 0.0) {
2 26 Feb 07 jari 687             newWeight = 0.0f;
2 26 Feb 07 jari 688         }
2 26 Feb 07 jari 689         
2 26 Feb 07 jari 690         
2 26 Feb 07 jari 691         return newWeight;
2 26 Feb 07 jari 692     }
2 26 Feb 07 jari 693     
2 26 Feb 07 jari 694     /*
2 26 Feb 07 jari 695      * Optimize the weights so that the discriminant function puts the
2 26 Feb 07 jari 696      * negatives close to -1 and the positives close to +1.
2 26 Feb 07 jari 697      */
2 26 Feb 07 jari 698     private long optimizeWeights(FloatMatrix kernelMatrix, float[] weights, int[] classes, int seed, float positiveConstraint, float negativeConstraint, float convergenceThreshold, boolean constrainWeights) throws AlgorithmException {
2 26 Feb 07 jari 699         int   randItem;
2 26 Feb 07 jari 700         long  iter = 0;
2 26 Feb 07 jari 701         float newWeight = 0.5f;
2 26 Feb 07 jari 702         float constraint;
2 26 Feb 07 jari 703         float INITIAL_VALUE = 0.5f; // Initial value for all weights.
2 26 Feb 07 jari 704         
2 26 Feb 07 jari 705         // Initialize the weights.
2 26 Feb 07 jari 706         for (int i=0; i<weights.length; i++) {
2 26 Feb 07 jari 707             weights[i] = INITIAL_VALUE;
2 26 Feb 07 jari 708         }
2 26 Feb 07 jari 709         
2 26 Feb 07 jari 710         AlgorithmEvent event = new AlgorithmEvent(this, AlgorithmEvent.MONITOR_VALUE);
2 26 Feb 07 jari 711         fireValueChanged(event); // to show monitor
2 26 Feb 07 jari 712         Random random = new Random(seed);
2 26 Feb 07 jari 713         // Iteratively improve the weights until convergence.
2 26 Feb 07 jari 714         while (!converged(event, kernelMatrix, weights, classes, convergenceThreshold)) {
2 26 Feb 07 jari 715             isStop();
2 26 Feb 07 jari 716             for (int i=0; i < classes.length; i++) {
2 26 Feb 07 jari 717                 // Randomly select a weight to update.
2 26 Feb 07 jari 718                 randItem = random.nextInt(classes.length);
2 26 Feb 07 jari 719                 // Set the constraint, based upon the class of this item.
2 26 Feb 07 jari 720                 if (classes[randItem] == 1) {
2 26 Feb 07 jari 721                     constraint = positiveConstraint;
2 26 Feb 07 jari 722                 } else {
2 26 Feb 07 jari 723                     constraint = negativeConstraint;
2 26 Feb 07 jari 724                 }
2 26 Feb 07 jari 725                 // Calculate the new weight.
2 26 Feb 07 jari 726                 newWeight = updateWeight(kernelMatrix, weights, classes, constraint, constrainWeights, randItem);
2 26 Feb 07 jari 727                 weights[randItem] = newWeight;
2 26 Feb 07 jari 728             }
2 26 Feb 07 jari 729             
2 26 Feb 07 jari 730             if(iter > 1000){
2 26 Feb 07 jari 731                 if(JOptionPane.showConfirmDialog( null, "                                Warning: 1000 iterations have failed to optimize weights.\n"+
2 26 Feb 07 jari 732                 "Please press OK to continue analysis using current weights OR press CANCEL to abort and try new parameters.\n","Weight Optimization Warning", JOptionPane.WARNING_MESSAGE, JOptionPane.WARNING_MESSAGE)
2 26 Feb 07 jari 733                 == JOptionPane.OK_OPTION)
2 26 Feb 07 jari 734                     break;
2 26 Feb 07 jari 735                 else
2 26 Feb 07 jari 736                     this.stop = true;
2 26 Feb 07 jari 737             }
2 26 Feb 07 jari 738             iter++;
2 26 Feb 07 jari 739         }
2 26 Feb 07 jari 740         return iter+1;
2 26 Feb 07 jari 741     }
2 26 Feb 07 jari 742     
2 26 Feb 07 jari 743     /*
2 26 Feb 07 jari 744      * Encode the classifications in the weights by multiplying the negative
2 26 Feb 07 jari 745      * examples by -1.
2 26 Feb 07 jari 746      */
2 26 Feb 07 jari 747     private void signWeights(float[] weights, int[] classes) {
2 26 Feb 07 jari 748         for (int i=0; i<classes.length; i++) {
2 26 Feb 07 jari 749             weights[i] = weights[i]*classes[i];
2 26 Feb 07 jari 750         }
2 26 Feb 07 jari 751     }
2 26 Feb 07 jari 752     
2 26 Feb 07 jari 753     /**
2 26 Feb 07 jari 754      * Creates classification FloatMatrix of class distribution ints (pos == 1, neg == -1) and discriminant values
2 26 Feb 07 jari 755      */
2 26 Feb 07 jari 756     private FloatMatrix classify(FloatMatrix trainingMatrix, float[] weights, float coefficient, float constant, float power) {
2 26 Feb 07 jari 757         
2 26 Feb 07 jari 758         AlgorithmEvent event = new AlgorithmEvent(this, AlgorithmEvent.PROGRESS_VALUE, 0);
2 26 Feb 07 jari 759         sendEvent(event, "CLASSIFYING\n");
2 26 Feb 07 jari 760         event.setDescription("Computing base kernel matrix\n");
2 26 Feb 07 jari 761         fireValueChanged(event);
2 26 Feb 07 jari 762         
2 26 Feb 07 jari 763         FloatMatrix kernelMatrix = computeNormalizedBaseKernelMatrix(trainingMatrix);
2 26 Feb 07 jari 764         
2 26 Feb 07 jari 765         float[] selfKernelValues = createSelfKernelValues(kernelMatrix);
2 26 Feb 07 jari 766         
2 26 Feb 07 jari 767         event.setDescription("Polynomializing kernel matrix\n");
2 26 Feb 07 jari 768         fireValueChanged(event);
2 26 Feb 07 jari 769         
2 26 Feb 07 jari 770         polynomializeMatrix(kernelMatrix, selfKernelValues, power, coefficient, constant);
2 26 Feb 07 jari 771         
2 26 Feb 07 jari 772         FloatMatrix discriminantMatrix = classifyList(kernelMatrix, weights);
2 26 Feb 07 jari 773         
2 26 Feb 07 jari 774         return discriminantMatrix;
2 26 Feb 07 jari 775     }
2 26 Feb 07 jari 776     
2 26 Feb 07 jari 777     /**
2 26 Feb 07 jari 778      * Trains SVM and returns float [] of weights
2 26 Feb 07 jari 779      */
2 26 Feb 07 jari 780     private float[] train(FloatMatrix trainingMatrix, int[] classes, int seed, boolean normalize, boolean radial, float coefficient, float constant, float power, float widthFactor, float positiveDiagonal, float negativeDiagonal, float diagonalFactor, float positiveConstraint, float negativeConstraint, float convergenceThreshold, boolean constrainWeights) throws AlgorithmException {
2 26 Feb 07 jari 781         
2 26 Feb 07 jari 782         AlgorithmEvent event = new AlgorithmEvent(this, AlgorithmEvent.PROGRESS_VALUE, 0);
2 26 Feb 07 jari 783         sendEvent(event, "TRAINING SVM\n");
2 26 Feb 07 jari 784         sendEvent(event, "Computing base kernel matrix\n");
2 26 Feb 07 jari 785         //FloatMatrix kernelMatrix = computeBaseKernelMatrix(trainingMatrix);
2 26 Feb 07 jari 786         
2 26 Feb 07 jari 787         
2 26 Feb 07 jari 788         FloatMatrix kernelMatrix = computeNormalizedBaseKernelMatrix(trainingMatrix);
2 26 Feb 07 jari 789         
2 26 Feb 07 jari 790         
2 26 Feb 07 jari 791         sendEvent(event, "Extract the diagonal from the kernel matrix.\n");
2 26 Feb 07 jari 792         float[] selfKernelValues = createSelfKernelValues(kernelMatrix);
2 26 Feb 07 jari 793         
2 26 Feb 07 jari 794         isStop();
2 26 Feb 07 jari 795         
2 26 Feb 07 jari 796         if (normalize) {
2 26 Feb 07 jari 797             sendEvent(event, "Normalizing kernel matrix\n");
2 26 Feb 07 jari 798             normalizeKernelMatrix(kernelMatrix, selfKernelValues);
2 26 Feb 07 jari 799         }
2 26 Feb 07 jari 800         
2 26 Feb 07 jari 801         isStop();
2 26 Feb 07 jari 802         
2 26 Feb 07 jari 803         
2 26 Feb 07 jari 804         if(!radial){
2 26 Feb 07 jari 805             sendEvent(event, "Polynomializing kernel matrix\n");
2 26 Feb 07 jari 806             polynomializeMatrix(kernelMatrix, selfKernelValues, power, coefficient, constant);
2 26 Feb 07 jari 807         }
2 26 Feb 07 jari 808         
2 26 Feb 07 jari 809         if (radial) {
2 26 Feb 07 jari 810             sendEvent(event, "Convert to a radial basis kernel.\n");
2 26 Feb 07 jari 811             float twoSquaredWidth = computeTwoSquaredWidth(kernelMatrix, classes, widthFactor);
2 26 Feb 07 jari 812             radializeMatrix(kernelMatrix, selfKernelValues, twoSquaredWidth, constant);
2 26 Feb 07 jari 813         }
2 26 Feb 07 jari 814         
2 26 Feb 07 jari 815         isStop();
2 26 Feb 07 jari 816         
2 26 Feb 07 jari 817         // Add constants to the diagonal.
2 26 Feb 07 jari 818         sendEvent(event, "Adding constants to kernel matrix\n");
2 26 Feb 07 jari 819         addToKernelDiagonal(kernelMatrix, selfKernelValues, classes, positiveDiagonal, negativeDiagonal, diagonalFactor);
2 26 Feb 07 jari 820         
2 26 Feb 07 jari 821         //printKernelDiagonal(kernelMatrix);
2 26 Feb 07 jari 822         
2 26 Feb 07 jari 823         isStop();
2 26 Feb 07 jari 824         
2 26 Feb 07 jari 825         // Initialize the weights to zeroes.
2 26 Feb 07 jari 826         float[] weights = new float[number_of_genes];
2 26 Feb 07 jari 827         // Optimize the weights.
2 26 Feb 07 jari 828         sendEvent(event, "Optimizing weights\n");
2 26 Feb 07 jari 829         optimizeWeights(kernelMatrix, weights, classes, seed, positiveConstraint, negativeConstraint, convergenceThreshold, constrainWeights);
2 26 Feb 07 jari 830         
2 26 Feb 07 jari 831         // Encode the classifications as the signs of the weights.
2 26 Feb 07 jari 832         sendEvent(event, "Encoding the classifications as the signs of the weights.\n");
2 26 Feb 07 jari 833         signWeights(weights, classes);
2 26 Feb 07 jari 834         
2 26 Feb 07 jari 835         return weights;
2 26 Feb 07 jari 836     }
2 26 Feb 07 jari 837     
2 26 Feb 07 jari 838     private void isStop() throws AbortException {
2 26 Feb 07 jari 839         if (stop) {
2 26 Feb 07 jari 840             throw new AbortException();
2 26 Feb 07 jari 841         }
2 26 Feb 07 jari 842     }
2 26 Feb 07 jari 843     
2 26 Feb 07 jari 844     private void sendEvent(AlgorithmEvent event, String description) {
2 26 Feb 07 jari 845         event.setDescription(description);
2 26 Feb 07 jari 846         fireValueChanged(event);
2 26 Feb 07 jari 847     }
2 26 Feb 07 jari 848     
2 26 Feb 07 jari 849     /**
2 26 Feb 07 jari 850      * Returns positive element index list
2 26 Feb 07 jari 851      */
2 26 Feb 07 jari 852     private int [] getPositives(FloatMatrix matrix){
2 26 Feb 07 jari 853         int cnt = 0;
2 26 Feb 07 jari 854         
2 26 Feb 07 jari 855         for(int i = 0; i < matrix.getRowDimension(); i++){
2 26 Feb 07 jari 856             if( matrix.get( i, 0 ) == 1.0 )
2 26 Feb 07 jari 857                 cnt++;
2 26 Feb 07 jari 858         }
2 26 Feb 07 jari 859         
2 26 Feb 07 jari 860         int [] pos = new int[cnt];
2 26 Feb 07 jari 861         cnt = 0;
2 26 Feb 07 jari 862         
2 26 Feb 07 jari 863         for(int i = 0; i < matrix.getRowDimension(); i++){
2 26 Feb 07 jari 864             if( matrix.get( i, 0 ) == 1.0 ){
2 26 Feb 07 jari 865                 pos[cnt] = i;
2 26 Feb 07 jari 866                 cnt++;
2 26 Feb 07 jari 867             }
2 26 Feb 07 jari 868         }
2 26 Feb 07 jari 869         return pos;
2 26 Feb 07 jari 870     }
2 26 Feb 07 jari 871     
2 26 Feb 07 jari 872     /**
2 26 Feb 07 jari 873      * Returns negative element index list
2 26 Feb 07 jari 874      */
2 26 Feb 07 jari 875     private int [] getNegatives(FloatMatrix matrix){
2 26 Feb 07 jari 876         int cnt = 0;
2 26 Feb 07 jari 877         for(int i = 0; i < matrix.getRowDimension(); i++){
2 26 Feb 07 jari 878             if( matrix.get( i, 0 ) <= 0 )
2 26 Feb 07 jari 879                 cnt++;
2 26 Feb 07 jari 880         }
2 26 Feb 07 jari 881         
2 26 Feb 07 jari 882         int [] neg = new int[cnt];
2 26 Feb 07 jari 883         cnt = 0;
2 26 Feb 07 jari 884         
2 26 Feb 07 jari 885         for(int i = 0; i < matrix.getRowDimension(); i++){
2 26 Feb 07 jari 886             if( matrix.get( i, 0 ) <= 0 ){
2 26 Feb 07 jari 887                 neg[cnt] = i;
2 26 Feb 07 jari 888                 cnt++;
2 26 Feb 07 jari 889             }
2 26 Feb 07 jari 890         }
2 26 Feb 07 jari 891         return neg;
2 26 Feb 07 jari 892     }
2 26 Feb 07 jari 893     
2 26 Feb 07 jari 894     
2 26 Feb 07 jari 895     /**
2 26 Feb 07 jari 896      *  Internal gene dot product
2 26 Feb 07 jari 897      */
2 26 Feb 07 jari 898     private float geneDotProduct(FloatMatrix matrix, FloatMatrix M, int g1, int g2) {
2 26 Feb 07 jari 899         if (M == null) {
2 26 Feb 07 jari 900             M = matrix;
2 26 Feb 07 jari 901         }
2 26 Feb 07 jari 902         int k=matrix.getColumnDimension();
2 26 Feb 07 jari 903         int n=0;
2 26 Feb 07 jari 904         double sum=0.0;
2 26 Feb 07 jari 905         for (int i=0; i<k; i++) {
2 26 Feb 07 jari 906             if ((!Float.isNaN(matrix.get(g1,i))) && (!Float.isNaN(M.get(g2,i)))) {
2 26 Feb 07 jari 907                 sum+=matrix.get(g1,i)*M.get(g2,i);
2 26 Feb 07 jari 908                 n++;
2 26 Feb 07 jari 909             }
2 26 Feb 07 jari 910         }
2 26 Feb 07 jari 911         return(float)(sum);
2 26 Feb 07 jari 912     }
2 26 Feb 07 jari 913     
2 26 Feb 07 jari 914     
2 26 Feb 07 jari 915     /**
2 26 Feb 07 jari 916      *  Retuns means values for each column within positives and negatives
2 26 Feb 07 jari 917      */
2 26 Feb 07 jari 918     private FloatMatrix getMeans(FloatMatrix discMatrix){
2 26 Feb 07 jari 919         int numSamples = this.expMatrix.getColumnDimension();
2 26 Feb 07 jari 920         int numGenes = this.expMatrix.getRowDimension();
2 26 Feb 07 jari 921         
2 26 Feb 07 jari 922         FloatMatrix means = new FloatMatrix(2, numSamples);
2 26 Feb 07 jari 923         float posMean = 0;
2 26 Feb 07 jari 924         float negMean = 0;
2 26 Feb 07 jari 925         float value;
2 26 Feb 07 jari 926         int posCnt = 0;
2 26 Feb 07 jari 927         int negCnt = 0;
2 26 Feb 07 jari 928         float c;
2 26 Feb 07 jari 929         
2 26 Feb 07 jari 930         for(int j = 0; j < numSamples; j++){
2 26 Feb 07 jari 931             for(int i = 0; i < numGenes; i++){
2 26 Feb 07 jari 932                 
2 26 Feb 07 jari 933                 c = discMatrix.get(i,0);
2 26 Feb 07 jari 934                 if(c == 1){
2 26 Feb 07 jari 935                     value = this.expMatrix.get(i,j);
2 26 Feb 07 jari 936                     if(!Float.isNaN(value)){
2 26 Feb 07 jari 937                         posCnt++;
2 26 Feb 07 jari 938                         posMean += value;
2 26 Feb 07 jari 939                     }
2 26 Feb 07 jari 940                 }
2 26 Feb 07 jari 941                 else{
2 26 Feb 07 jari 942                     
2 26 Feb 07 jari 943                     value = this.expMatrix.get(i,j);
2 26 Feb 07 jari 944                     if(!Float.isNaN(value)){
2 26 Feb 07 jari 945                         negCnt++;
2 26 Feb 07 jari 946                         negMean += value;
2 26 Feb 07 jari 947                     }
2 26 Feb 07 jari 948                 }
2 26 Feb 07 jari 949             }
2 26 Feb 07 jari 950             means.set( 0, j, (float)(posCnt != 0 ? posMean/posCnt : 0.0f));
2 26 Feb 07 jari 951             means.set( 1, j, (float)(negCnt != 0 ? negMean/negCnt : 0.0f));
2 26 Feb 07 jari 952             posCnt = 0;
2 26 Feb 07 jari 953             negCnt = 0;
2 26 Feb 07 jari 954             posMean = 0;
2 26 Feb 07 jari 955             negMean = 0;
2 26 Feb 07 jari 956         }
2 26 Feb 07 jari 957         return means;
2 26 Feb 07 jari 958     }
2 26 Feb 07 jari 959     
2 26 Feb 07 jari 960     /**
2 26 Feb 07 jari 961      *  Retuns variance values for each column within positives and negatives
2 26 Feb 07 jari 962      */
2 26 Feb 07 jari 963     private FloatMatrix getVariance(FloatMatrix discMatrix, FloatMatrix means){
2 26 Feb 07 jari 964         int numSamples = this.expMatrix.getColumnDimension();
2 26 Feb 07 jari 965         int numGenes = this.expMatrix.getRowDimension();
2 26 Feb 07 jari 966         FloatMatrix vars = new FloatMatrix(2, numSamples);
2 26 Feb 07 jari 967         float value;
2 26 Feb 07 jari 968         float c;
2 26 Feb 07 jari 969         float mean;
2 26 Feb 07 jari 970         float ssePos = 0;
2 26 Feb 07 jari 971         int posCnt = 0;
2 26 Feb 07 jari 972         float sseNeg = 0;
2 26 Feb 07 jari 973         int negCnt = 0;
2 26 Feb 07 jari 974         for(int i = 0; i < numSamples; i++){
2 26 Feb 07 jari 975             
2 26 Feb 07 jari 976             for(int j = 0; j < numGenes; j++){
2 26 Feb 07 jari 977                 c = discMatrix.get(j, 0);
2 26 Feb 07 jari 978                 
2 26 Feb 07 jari 979                 if(c == 1){
2 26 Feb 07 jari 980                     value = expMatrix.get(j,i);
2 26 Feb 07 jari 981                     if(!Float.isNaN(value)){
2 26 Feb 07 jari 982                         ssePos += Math.pow(value - means.get(0, i), 2);
2 26 Feb 07 jari 983                         posCnt++;
2 26 Feb 07 jari 984                     }
2 26 Feb 07 jari 985                 }
2 26 Feb 07 jari 986                 else{
2 26 Feb 07 jari 987                     value = expMatrix.get(j,i);
2 26 Feb 07 jari 988                     if(!Float.isNaN(value)){                        
2 26 Feb 07 jari 989                         sseNeg += Math.pow(value - means.get(1, i), 2);
2 26 Feb 07 jari 990                         negCnt++;
2 26 Feb 07 jari 991                     }
2 26 Feb 07 jari 992                 }
2 26 Feb 07 jari 993             }
2 26 Feb 07 jari 994             vars.set( 0, i, (float)(posCnt > 1 ? Math.sqrt(ssePos/(posCnt - 1)) : 0.0f));
2 26 Feb 07 jari 995             vars.set( 1, i, (float)(negCnt > 1 ? Math.sqrt(sseNeg/(negCnt - 1)) : 0.0f));
2 26 Feb 07 jari 996             posCnt = 0;
2 26 Feb 07 jari 997             negCnt = 0;
2 26 Feb 07 jari 998             ssePos = 0;
2 26 Feb 07 jari 999             sseNeg = 0;
2 26 Feb 07 jari 1000         }
2 26 Feb 07 jari 1001         return vars;
2 26 Feb 07 jari 1002     }
2 26 Feb 07 jari 1003     
2 26 Feb 07 jari 1004     /**
2 26 Feb 07 jari 1005      * Creates HCL results
2 26 Feb 07 jari 1006      */
2 26 Feb 07 jari 1007     private NodeValueList calculateHierarchicalTree(int[] features, int method, boolean genes, boolean experiments) throws AlgorithmException {
2 26 Feb 07 jari 1008         NodeValueList nodeList = new NodeValueList();
2 26 Feb 07 jari 1009         AlgorithmData data = new AlgorithmData();
2 26 Feb 07 jari 1010         FloatMatrix experiment;
2 26 Feb 07 jari 1011         if(svmGenes)
2 26 Feb 07 jari 1012             experiment = getSubExperiment(this.expMatrix, features);
2 26 Feb 07 jari 1013         else
2 26 Feb 07 jari 1014             experiment = getSubExperimentReducedCols(this.expMatrix, features);
2 26 Feb 07 jari 1015         
2 26 Feb 07 jari 1016         data.addMatrix("experiment", experiment);
2 26 Feb 07 jari 1017         System.out.println("In SVM algorithm , metric for HCL = "+this.function  );
2 26 Feb 07 jari 1018         data.addParam("hcl-distance-function", String.valueOf(this.function));
2 26 Feb 07 jari 1019         data.addParam("hcl-distance-absolute", String.valueOf(this.absolute));
2 26 Feb 07 jari 1020         data.addParam("method-linkage", String.valueOf(method));
2 26 Feb 07 jari 1021         HCL hcl = new HCL();
2 26 Feb 07 jari 1022         AlgorithmData result;
2 26 Feb 07 jari 1023         
2 26 Feb 07 jari 1024         if (genes) {
2 26 Feb 07 jari 1025             data.addParam("calculate-genes", String.valueOf(true));
2 26 Feb 07 jari 1026             result = hcl.execute(data);
2 26 Feb 07 jari 1027             validate(result);
2 26 Feb 07 jari 1028             addNodeValues(nodeList, result);
2 26 Feb 07 jari 1029         }
2 26 Feb 07 jari 1030         if (experiments) {
2 26 Feb 07 jari 1031             data.addParam("calculate-genes", String.valueOf(false));
2 26 Feb 07 jari 1032             result = hcl.execute(data);
2 26 Feb 07 jari 1033             int [] nodes = result.getIntArray("node-order");
2 26 Feb 07 jari 1034             validate(result);
2 26 Feb 07 jari 1035             addNodeValues(nodeList, result);
2 26 Feb 07 jari 1036         }
2 26 Feb 07 jari 1037         return nodeList;
2 26 Feb 07 jari 1038     }
2 26 Feb 07 jari 1039     
2 26 Feb 07 jari 1040     
2 26 Feb 07 jari 1041     /**
2 26 Feb 07 jari 1042      * Accumulates hcl results
2 26 Feb 07 jari 1043      */
2 26 Feb 07 jari 1044     private void addNodeValues(NodeValueList target_list, AlgorithmData source_result) {
2 26 Feb 07 jari 1045         target_list.addNodeValue(new NodeValue("child-1-array", source_result.getIntArray("child-1-array")));
2 26 Feb 07 jari 1046         target_list.addNodeValue(new NodeValue("child-2-array", source_result.getIntArray("child-2-array")));
2 26 Feb 07 jari 1047         target_list.addNodeValue(new NodeValue("node-order", source_result.getIntArray("node-order")));
2 26 Feb 07 jari 1048         target_list.addNodeValue(new NodeValue("height", source_result.getMatrix("height").getRowPackedCopy()));
2 26 Feb 07 jari 1049     }
2 26 Feb 07 jari 1050     
2 26 Feb 07 jari 1051     /**
2 26 Feb 07 jari 1052      *  Gets sub experiment (cluster membership only, dictated by features)
2 26 Feb 07 jari 1053      */
2 26 Feb 07 jari 1054     private FloatMatrix getSubExperiment(FloatMatrix experiment, int[] features) {
2 26 Feb 07 jari 1055         FloatMatrix subExperiment = new FloatMatrix(features.length, experiment.getColumnDimension());
2 26 Feb 07 jari 1056         for (int i=0; i<features.length; i++) {
2 26 Feb 07 jari 1057             subExperiment.A[i] = experiment.A[features[i]];
2 26 Feb 07 jari 1058         }
2 26 Feb 07 jari 1059         return subExperiment;
2 26 Feb 07 jari 1060     }
2 26 Feb 07 jari 1061     
2 26 Feb 07 jari 1062     /**
2 26 Feb 07 jari 1063      *  Creates a matrix with reduced columns (samples) as during experiment classification
2 26 Feb 07 jari 1064      */
2 26 Feb 07 jari 1065     private FloatMatrix getSubExperimentReducedCols(FloatMatrix experiment, int[] features) {
2 26 Feb 07 jari 1066         FloatMatrix copyMatrix = experiment.copy();
2 26 Feb 07 jari 1067         FloatMatrix subExperiment = new FloatMatrix(features.length, copyMatrix.getColumnDimension());
2 26 Feb 07 jari 1068         for (int i=0; i<features.length; i++) {
2 26 Feb 07 jari 1069             subExperiment.A[i] = copyMatrix.A[features[i]];
2 26 Feb 07 jari 1070         }
2 26 Feb 07 jari 1071         subExperiment = subExperiment.transpose();
2 26 Feb 07 jari 1072         return subExperiment;
2 26 Feb 07 jari 1073     }
2 26 Feb 07 jari 1074     
2 26 Feb 07 jari 1075     /**
2 26 Feb 07 jari 1076      * Checks the result of hcl algorithm calculation.
2 26 Feb 07 jari 1077      * @throws AlgorithmException, if the result is incorrect.
2 26 Feb 07 jari 1078      */
2 26 Feb 07 jari 1079     private void validate(AlgorithmData result) throws AlgorithmException {
2 26 Feb 07 jari 1080         if (result.getIntArray("child-1-array") == null) {
2 26 Feb 07 jari 1081             throw new AlgorithmException("parameter 'child-1-array' is null");
2 26 Feb 07 jari 1082         }
2 26 Feb 07 jari 1083         if (result.getIntArray("child-2-array") == null) {
2 26 Feb 07 jari 1084             throw new AlgorithmException("parameter 'child-2-array' is null");
2 26 Feb 07 jari 1085         }
2 26 Feb 07 jari 1086         if (result.getIntArray("node-order") == null) {
2 26 Feb 07 jari 1087             throw new AlgorithmException("parameter 'node-order' is null");
2 26 Feb 07 jari 1088         }
2 26 Feb 07 jari 1089         if (result.getMatrix("height") == null) {
2 26 Feb 07 jari 1090             throw new AlgorithmException("parameter 'height' is null");
2 26 Feb 07 jari 1091         }
2 26 Feb 07 jari 1092     }
2 26 Feb 07 jari 1093     
2 26 Feb 07 jari 1094     private int[] convert2int(ArrayList source) {
2 26 Feb 07 jari 1095         int[] int_matrix = new int[source.size()];
2 26 Feb 07 jari 1096         for (int i=0; i<int_matrix.length; i++) {
2 26 Feb 07 jari 1097             int_matrix[i] = (int)((Float)source.get(i)).floatValue();
2 26 Feb 07 jari 1098         }
2 26 Feb 07 jari 1099         return int_matrix;
2 26 Feb 07 jari 1100     }
2 26 Feb 07 jari 1101     
2 26 Feb 07 jari 1102     
2 26 Feb 07 jari 1103     
2 26 Feb 07 jari 1104     
2 26 Feb 07 jari 1105     //*************** debug *******************
2 26 Feb 07 jari 1106     private void printMatrix(String title, FloatMatrix matrix) {
2 26 Feb 07 jari 1107         System.out.println("===== "+title+" =====");
2 26 Feb 07 jari 1108         matrix.print(5, 2);
2 26 Feb 07 jari 1109     }
2 26 Feb 07 jari 1110     
2 26 Feb 07 jari 1111     private void printFloatArray(String title, float[] floatArray) {
2 26 Feb 07 jari 1112         System.out.println("===== "+title+" =====");
2 26 Feb 07 jari 1113         for (int i=0; i<floatArray.length; i++) {
2 26 Feb 07 jari 1114             System.out.print(floatArray[i]+" ");
2 26 Feb 07 jari 1115         }
2 26 Feb 07 jari 1116         System.out.println();
2 26 Feb 07 jari 1117     }
2 26 Feb 07 jari 1118     
2 26 Feb 07 jari 1119     private void printKernelDiagonal(FloatMatrix matrix){
2 26 Feb 07 jari 1120         for(int i = 0; i < matrix.getRowDimension(); i++){
2 26 Feb 07 jari 1121             System.out.println("Kernal diagonal " + matrix.get(i,i));
2 26 Feb 07 jari 1122         }
2 26 Feb 07 jari 1123     }
2 26 Feb 07 jari 1124     
2 26 Feb 07 jari 1125     private void printWeights(float [] w){
2 26 Feb 07 jari 1126         for(int i = 0; i < w.length; i++){
2 26 Feb 07 jari 1127             System.out.println("Weight = "+ w[i]);
2 26 Feb 07 jari 1128         }
2 26 Feb 07 jari 1129     }
2 26 Feb 07 jari 1130     //**************end debug methods********************
2 26 Feb 07 jari 1131     
2 26 Feb 07 jari 1132 }