public void testCorrectness( int i, int cluster, int clusterIndex1, int clusterIndex2, int[] shuffle, int[] mergedClusterSites, double[] lik1, double[] lik2) throws Exception { // System.out.println("Hi!"); int[] tempWeights = new int[tempLikelihood.dataInput.get().getPatternCount()]; tempWeights[tempLikelihood.dataInput.get().getPatternIndex(mergedClusterSites[shuffle[i]])] = 1; tempLikelihood.setPatternWeights(tempWeights); double temp1 = Math.exp( tempLikelihood .calculateLogP( alphaList.getParameter(clusterIndex1).getValue(), invPrList.getParameter(clusterIndex1).getValue(), ratesList.getParameter(clusterIndex1).getValue(), siteModelList.getParameter(clusterIndex1).getValue(), new int[] {mergedClusterSites[shuffle[i]]})[0]); double temp2 = Math.exp( tempLikelihood .calculateLogP( alphaList.getParameter(clusterIndex2).getValue(), invPrList.getParameter(clusterIndex2).getValue(), ratesList.getParameter(clusterIndex2).getValue(), siteModelList.getParameter(clusterIndex2).getValue(), new int[] {mergedClusterSites[shuffle[i]]})[0]); if (temp1 != lik1[shuffle[i]] || temp2 != lik2[shuffle[i]]) { System.out.println("temp1"); System.out.println("shuffle_i: " + shuffle[i]); System.out.println("mergedClusterSites[shuffle]: " + mergedClusterSites[shuffle[i]]); System.out.println("cluster: " + cluster); System.out.println(+mergedClusterSites.length + " " + lik1.length); for (int j = 0; j < lik1.length; j++) { System.out.println("merged lik1: " + mergedClusterSites[j] + " " + lik1[j]); } for (int j = 0; j < lik2.length; j++) { System.out.println("merged lik2: " + mergedClusterSites[j] + " " + lik2[j]); } throw new RuntimeException( temp1 + " " + lik1[shuffle[i]] + " " + temp2 + " " + lik2[shuffle[i]]); } }
public double merge( int index1, int index2, int clusterIndex1, int clusterIndex2, int[] cluster1Sites, int[] cluster2Sites) { /*if(Math.abs(modelList.getParameter(clusterIndex1).getValue() - modelList.getParameter(clusterIndex2).getValue()) > 1.0){ return Double.NEGATIVE_INFINITY; }*/ double logqMerge = 0.0; HashMap<Integer, Integer> siteMap = new HashMap<Integer, Integer>(); // The value of the merged cluster will have that of cluster 2 before the merge. QuietRealParameter mergedRates = ratesList.getParameter(clusterIndex2); QuietRealParameter mergedAlpha = alphaList.getParameter(clusterIndex2); QuietRealParameter mergedInvPr = invPrList.getParameter(clusterIndex2); QuietRealParameter mergedSiteModel = siteModelList.getParameter(clusterIndex2); // Create a vector that combines the site indices of the two clusters int[] mergedClusterSites = new int[cluster1Sites.length + cluster2Sites.length - 2]; int k = 0; for (int i = 0; i < cluster1Sites.length; i++) { if (cluster1Sites[i] != index1) { // For all members that are not index 1, // record the cluster in which they have been before the merge, // and assign them to the combined vector. siteMap.put(cluster1Sites[i], clusterIndex1); mergedClusterSites[k++] = cluster1Sites[i]; } } for (int i = 0; i < cluster2Sites.length; i++) { // All members in cluster 2 remains in cluster2 so no new pointer assignments if (cluster2Sites[i] != index2) { // For all members that are not index 2, // record the cluster in which they have been before the merge, // and assign them to the combined vector. siteMap.put(cluster2Sites[i], clusterIndex2); try { mergedClusterSites[k++] = cluster2Sites[i]; } catch (Exception e) { System.out.println("k: " + k); System.out.println("i: " + i); System.out.println("cluster2Sites.length: " + cluster2Sites.length); System.out.println("index2: " + index2); for (int index : cluster2Sites) { System.out.print(index + " "); } System.out.println(); throw new RuntimeException(""); } } } try { // Create a weight vector of patterns to inform the temporary tree likelihood // which set of pattern likelihoods are to be computed. // int[] tempWeights = dpTreeLikelihood.getClusterWeights(clusterIndex1); /*int[] tempWeights = new int[tempLikelihood.m_data.get().getPatternCount()]; for(int i = 0; i < cluster1Sites.length; i++){ int patIndex = tempLikelihood.m_data.get().getPatternIndex(cluster1Sites[i]); tempWeights[patIndex] = 1; } tempLikelihood.setPatternWeights(tempWeights); double[] cluster1SitesCluster2ParamLogLik = tempLikelihood.calculateLogP( mergedParam, mergedModel, mergedFreqs, cluster1Sites, index1 );*/ k = 0; int[] sCluster1Sites = new int[cluster1Sites.length - 1]; for (int i = 0; i < cluster1Sites.length; i++) { if (cluster1Sites[i] != index1) { sCluster1Sites[k++] = cluster1Sites[i]; } } tempLikelihood.setupPatternWeightsFromSites(sCluster1Sites); double[] cluster1SitesCluster2ParamLogLik = tempLikelihood.calculateLogP( mergedAlpha.getValue(), mergedInvPr.getValue(), mergedRates.getValue(), mergedSiteModel.getValue(), sCluster1Sites); // tempWeights = dpTreeLikelihood.getClusterWeights(clusterIndex2); /*tempWeights = new int[tempLikelihood.m_data.get().getPatternCount()]; for(int i = 0; i < cluster2Sites.length; i++){ int patIndex = tempLikelihood.m_data.get().getPatternIndex(cluster2Sites[i]); tempWeights[patIndex] = 1; } tempLikelihood.setPatternWeights(tempWeights); QuietRealParameter removedParam = paramList.getParameter(clusterIndex1); QuietRealParameter removedModel = modelList.getParameter(clusterIndex1); QuietRealParameter removedFreqs = freqsList.getParameter(clusterIndex1); double[] cluster2SitesCluster1ParamLogLik = tempLikelihood.calculateLogP( removedParam, removedModel, removedFreqs, cluster2Sites, index2 ); */ k = 0; int[] sCluster2Sites = new int[cluster2Sites.length - 1]; for (int i = 0; i < cluster2Sites.length; i++) { if (cluster2Sites[i] != index2) { sCluster2Sites[k++] = cluster2Sites[i]; } } tempLikelihood.setupPatternWeightsFromSites(sCluster2Sites); QuietRealParameter removedAlpha = alphaList.getParameter(clusterIndex1); QuietRealParameter removedInvPr = invPrList.getParameter(clusterIndex1); QuietRealParameter removedRates = ratesList.getParameter(clusterIndex1); QuietRealParameter removedSiteModel = siteModelList.getParameter(clusterIndex1); double[] cluster2SitesCluster1ParamLogLik = tempLikelihood.calculateLogP( removedAlpha.getValue(), removedInvPr.getValue(), removedRates.getValue(), removedSiteModel.getValue(), sCluster2Sites); // System.out.println("populate logLik1:"); double[] logLik1 = new double[mergedClusterSites.length]; for (int i = 0; i < (cluster1Sites.length - 1); i++) { // System.out.println(clusterIndex1+" "+mergedClusterSites[i]); // logLik1[i] = dpTreeLikelihood.getSiteLogLikelihood(clusterIndex1,mergedClusterSites[i]); logLik1[i] = getSiteLogLikelihood(removedRates.getIDNumber(), clusterIndex1, mergedClusterSites[i]); } /*System.out.println(cluster2SitesCluster1ParamLogLik.length); System.out.println(logLik1.length); System.out.println(cluster1Sites.length-1); System.out.println(cluster2SitesCluster1ParamLogLik.length);*/ System.arraycopy( cluster2SitesCluster1ParamLogLik, 0, logLik1, cluster1Sites.length - 1, cluster2SitesCluster1ParamLogLik.length); double[] logLik2 = new double[mergedClusterSites.length]; System.arraycopy( cluster1SitesCluster2ParamLogLik, 0, logLik2, 0, cluster1SitesCluster2ParamLogLik.length); // System.out.println("populate logLik2:"); for (int i = cluster1SitesCluster2ParamLogLik.length; i < logLik2.length; i++) { // System.out.println(clusterIndex2+" // "+mergedClusterSites[i-cluster1SitesCluster2ParamLogLik.length]); // logLik2[i] = dpTreeLikelihood.getSiteLogLikelihood(clusterIndex2,mergedClusterSites[i]); logLik2[i] = getSiteLogLikelihood(mergedRates.getIDNumber(), clusterIndex2, mergedClusterSites[i]); } double[] lik1 = new double[logLik1.length]; double[] lik2 = new double[logLik2.length]; double maxLog; // scale it so it may be more accurate for (int i = 0; i < logLik1.length; i++) { maxLog = Math.max(logLik1[i], logLik2[i]); // System.out.println(i+" "+logLik1[i]+" "+logLik2[i]); if (Math.exp(maxLog) < 1e-100) { if (maxLog == logLik1[i]) { lik1[i] = 1.0; lik2[i] = Math.exp(logLik2[i] - maxLog); // System.out.println(i+" "+lik1[i]+" "+lik2[i]); } else { lik1[i] = Math.exp(logLik1[i] - maxLog); lik2[i] = 1.0; // System.out.println(i+" "+lik1[i]+" "+lik2[i]); } } else { lik1[i] = Math.exp(logLik1[i]); lik2[i] = Math.exp(logLik2[i]); } } /*for(int i = 0; i < logLik1.length; i++){ if(Double.isNaN(logLik1[i])){ //System.out.println("logLik1: "+logLik1[i]); return Double.NEGATIVE_INFINITY; } if(Double.isNaN(logLik2[i])){ //System.out.println("logLik2: "+logLik2[i]); return Double.NEGATIVE_INFINITY; } lik1[i] = Math.exp(logLik1[i]); lik2[i] = Math.exp(logLik2[i]); //System.out.println(lik1[i]+" "+lik2[i]); } */ // Create a set of indices for random permutation int[] shuffle = new int[mergedClusterSites.length]; for (int i = 0; i < shuffle.length; i++) { shuffle[i] = i; } Randomizer.shuffle(shuffle); int cluster1Count = 1; int cluster2Count = 1; int cluster; double psi1, psi2, cluster1Prob; for (int i = 0; i < mergedClusterSites.length; i++) { cluster = siteMap.get(mergedClusterSites[shuffle[i]]); psi1 = cluster1Count * lik1[shuffle[i]]; psi2 = cluster2Count * lik2[shuffle[i]]; // System.out.println(psi1+" "+psi2); if (testCorrect) { testCorrectness( i, cluster, clusterIndex1, clusterIndex2, shuffle, mergedClusterSites, lik1, lik2); } cluster1Prob = psi1 / (psi1 + psi2); if (cluster == clusterIndex1) { logqMerge += Math.log(cluster1Prob); cluster1Count++; } else if (cluster == clusterIndex2) { logqMerge += Math.log(1 - cluster1Prob); cluster2Count++; } else { throw new RuntimeException("Something is wrong."); } } logqMerge += // paramBaseDistr.calcLogP(removedParam)+ mergeValueInLogSpace(removedRates, mergedRates, ratesBaseDistr) + mergeValueInLogSpace(removedAlpha, mergedAlpha, alphaBaseDistr) + mergeValue(removedInvPr, mergedInvPr, invPrBaseDistr) + // modelBaseDistr.calcLogP(removedModel)+ mergeDiscreteValue(removedSiteModel, mergedSiteModel, siteModelBaseDistr); if (logqMerge > Double.NEGATIVE_INFINITY) { ratesList.mergeParameter(clusterIndex1, clusterIndex2); alphaList.mergeParameter(clusterIndex1, clusterIndex2); invPrList.mergeParameter(clusterIndex1, clusterIndex2); siteModelList.mergeParameter(clusterIndex1, clusterIndex2); for (int i = 0; i < cluster1Sites.length; i++) { // Point every member in cluster 1 to cluster 2 ratesPointers.point(cluster1Sites[i], mergedRates); } } } catch (Exception e) { throw new RuntimeException(e); } return logqMerge; }
public double split(int index1, int index2, int clusterIndex, int[] initClusterSites) { try { double logqSplit = 0.0; // Create a parameter by sampling from the prior // QuietRealParameter newParam = getSample(paramBaseDistr, paramList.getUpper(), // paramList.getLower()); QuietRealParameter newRates = new QuietRealParameter(new Double[1]); logqSplit += proposeNewValueInLogSpace( newRates, ratesList.getValue(clusterIndex, 0), ratesBaseDistr, ratesList.getUpper(), ratesList.getLower()); ratesList.getValues(clusterIndex); // QuietRealParameter newModel = getSample(modelBaseDistr, modelList.getUpper(), // modelList.getLower()); QuietRealParameter newAlpha = new QuietRealParameter(new Double[1]); logqSplit += proposeNewValueInLogSpace( newAlpha, alphaList.getValue(clusterIndex, 0), alphaBaseDistr, alphaList.getUpper(), alphaList.getLower()); // QuietRealParameter newFreqs = getSample(freqsBaseDistr, freqsList.getUpper(), // freqsList.getLower()); QuietRealParameter newInvPr = new QuietRealParameter(new Double[1]); logqSplit += proposeNewValue( newInvPr, invPrList.getValues(clusterIndex), invPrBaseDistr, invPrList.getUpper(), invPrList.getLower()); QuietRealParameter newSiteModel = new QuietRealParameter(new Double[1]); logqSplit += proposeDiscreteValue( newSiteModel, siteModelList.getValue(clusterIndex, 0), siteModelBaseDistr, siteModelList.getUpper(), siteModelList.getLower()); // Perform a split // paramList.splitParameter(clusterIndex,newParam); // modelList.splitParameter(clusterIndex,newModel); // freqsList.splitParameter(clusterIndex,newFreqs); // Remove the index 1 and index 2 from the cluster int[] clusterSites = new int[initClusterSites.length - 2]; int k = 0; for (int i = 0; i < initClusterSites.length; i++) { if (initClusterSites[i] != index1 && initClusterSites[i] != index2) { clusterSites[k++] = initClusterSites[i]; } } // Form a new cluster with index 1 // paramPointers.point(index1,newParam); // modelPointers.point(index1,newModel); // freqsPointers.point(index1,newFreqs); // Shuffle the cluster_-{index_1,index_2} to obtain a random permutation Randomizer.shuffle(clusterSites); // Create the weight vector of site patterns according to the order of the shuffled index. /*int[] tempWeights = new int[tempLikelihood.m_data.get().getPatternCount()]; int patIndex; for(int i = 0; i < clusterSites.length; i++){ patIndex = tempLikelihood.m_data.get().getPatternIndex(clusterSites[i]); tempWeights[patIndex] = 1; }*/ tempLikelihood.setupPatternWeightsFromSites(clusterSites); // Site log likelihoods in the order of the shuffled sites double[] logLik1 = tempLikelihood.calculateLogP( newAlpha.getValue(), newInvPr.getValue(), newRates.getValue(), newSiteModel.getValue(), clusterSites); double[] logLik2 = new double[clusterSites.length]; for (int i = 0; i < logLik2.length; i++) { // logLik2[i] = dpTreeLikelihood.getSiteLogLikelihood(clusterIndex,clusterSites[i]); logLik2[i] = getSiteLogLikelihood( ratesList.getParameterIDNumber(clusterIndex), clusterIndex, clusterSites[i]); } double[] lik1 = new double[logLik1.length]; double[] lik2 = new double[logLik2.length]; double maxLog; // scale it so it may be more accurate for (int i = 0; i < logLik1.length; i++) { maxLog = Math.min(logLik1[i], logLik2[i]); if (Math.exp(maxLog) < 1e-100) { if (maxLog == logLik1[i]) { lik1[i] = 1.0; lik2[i] = Math.exp(logLik2[i] - maxLog); } else { lik1[i] = Math.exp(logLik1[i] - maxLog); lik2[i] = 1.0; } } else { lik1[i] = Math.exp(logLik1[i]); lik2[i] = Math.exp(logLik2[i]); } } /*boolean ohCrap = false; for(int i = 0; i < logLik1.length; i++){ if(Double.isNaN(logLik1[i])){ return Double.NEGATIVE_INFINITY; //ohCrap = true; //System.out.println("logLik1: "+logLik1); //logLik1[i] = Double.NEGATIVE_INFINITY; } if(Double.isNaN(logLik2[i])){ return Double.NEGATIVE_INFINITY; //ohCrap = true; //System.out.println("logLik1: "+logLik2); //logLik2[i] = Double.NEGATIVE_INFINITY; } lik1[i] = Math.exp(logLik1[i]); lik2[i] = Math.exp(logLik2[i]); //System.out.println(lik1[i]+" "+lik2[i]); } if(ohCrap){ for(int i = 0; i < newRates.getDimension();i++){ System.out.print(newRates.getValue(i)+" "); } System.out.println(); } */ /*for(int i = 0; i < clusterSites.length;i++){ System.out.println("clusterSites: "+clusterSites[i]); } System.out.println("index 1: "+index1+" index2: "+index2);*/ int cluster1Count = 1; int cluster2Count = 1; // Assign members of the existing cluster (except for indice 1 and 2) randomly // to the existing and the new cluster double psi1, psi2, newClusterProb, draw; int[] newAssignment = new int[clusterSites.length]; for (int i = 0; i < clusterSites.length; i++) { psi1 = cluster1Count * lik1[i]; psi2 = cluster2Count * lik2[i]; newClusterProb = psi1 / (psi1 + psi2); draw = Randomizer.nextDouble(); if (draw < newClusterProb) { // System.out.println("in new cluster: "+clusterSites[i]); // paramPointers.point(clusterSites[i],newParam); // modelPointers.point(clusterSites[i],newModel); // freqsPointers.point(clusterSites[i],newFreqs); newAssignment[cluster1Count - 1] = clusterSites[i]; logqSplit += Math.log(newClusterProb); cluster1Count++; } else { logqSplit += Math.log(1.0 - newClusterProb); cluster2Count++; } } // logqSplit += //paramBaseDistr.calcLogP(newParam) + // modelBaseDistr.calcLogP(newModel)+ // freqsBaseDistr.calcLogP(newFreqs); if (-logqSplit > Double.NEGATIVE_INFINITY) { ratesList = ratesListInput.get(this); alphaList = alphaListInput.get(this); invPrList = invPrListInput.get(this); siteModelList = siteModelListInput.get(this); ratesPointers = ratesPointersInput.get(this); // Perform a split ratesList.splitParameter(clusterIndex, newRates); alphaList.splitParameter(clusterIndex, newAlpha); invPrList.splitParameter(clusterIndex, newInvPr); siteModelList.splitParameter(clusterIndex, newSiteModel); // Form a new cluster with index 1 ratesPointers.point(index1, newRates); for (int i = 0; i < (cluster1Count - 1); i++) { ratesPointers.point(newAssignment[i], newRates); } } return -logqSplit; } catch (Exception e) { // freqsBaseDistr.printDetails(); throw new RuntimeException(e); } }