public double split(int index1, int index2, int clusterIndex, int[] initClusterSites) { try { double logqSplit = 0.0; // Create a parameter by sampling from the prior // QuietRealParameter newParam = getSample(paramBaseDistr, paramList.getUpper(), // paramList.getLower()); QuietRealParameter newParam = new QuietRealParameter(new Double[5]); // logqSplit += proposeNewValue(newParam, paramBaseDistr, paramList.getUpper(), // paramList.getLower()); double[] oldParamValues = new double[5]; for (int i = 0; i < oldParamValues.length; i++) { oldParamValues[i] = paramList.getValue(clusterIndex, i); } logqSplit += proposeNewValue2( newParam, oldParamValues, paramBaseDistr, paramList.getUpper(), paramList.getLower()); // QuietRealParameter newModel = getSample(modelBaseDistr, modelList.getUpper(), // modelList.getLower()); QuietRealParameter newModel = new QuietRealParameter(new Double[1]); logqSplit += proposeDiscreteValue( newModel, modelList.getValue(clusterIndex, 0), modelDistrInput.get(), modelList.getUpper(), modelList.getLower()); QuietRealParameter newFreqs = getSample(freqsBaseDistr, freqsList.getUpper(), freqsList.getLower()); // QuietRealParameter newRates = getSample(ratesBaseDistr, ratesList.getUpper(), // ratesList.getLower()); QuietRealParameter newRates = new QuietRealParameter(new Double[1]); logqSplit += proposalValueInLogSpace( newRates, ratesList.getValue(clusterIndex, 0), ratesBaseDistr, ratesList.getUpper(), ratesList.getLower()); // Remove the index 1 and index 2 from the cluster int[] clusterSites = new int[initClusterSites.length - 2]; int k = 0; for (int i = 0; i < initClusterSites.length; i++) { if (initClusterSites[i] != index1 && initClusterSites[i] != index2) { clusterSites[k++] = initClusterSites[i]; } } // Form a new cluster with index 1 paramPointers.point(index1, newParam); modelPointers.point(index1, newModel); freqsPointers.point(index1, newFreqs); ratesPointers.point(index1, newRates); // Shuffle the cluster_-{index_1,index_2} to obtain a random permutation Randomizer.shuffle(clusterSites); // Create the weight vector of site patterns according to the order of the shuffled index. /*int[] tempWeights = new int[tempLikelihood.m_data.get().getPatternCount()]; int patIndex; for(int i = 0; i < clusterSites.length; i++){ patIndex = tempLikelihood.m_data.get().getPatternIndex(clusterSites[i]); tempWeights[patIndex] = 1; } tempLikelihood.setPatternWeights(tempWeights);*/ tempLikelihood.setupPatternWeightsFromSites(clusterSites); // Site log likelihoods in the order of the shuffled sites double[] logLik1 = tempLikelihood.calculateLogP(newParam, newModel, newFreqs, newRates, clusterSites); double[] logLik2 = new double[clusterSites.length]; for (int i = 0; i < logLik2.length; i++) { logLik2[i] = dpTreeLikelihood.getSiteLogLikelihood(clusterIndex, clusterSites[i]); } double[] lik1 = new double[logLik1.length]; double[] lik2 = new double[logLik2.length]; double minLog; // scale it so it may be more accurate /*for(int i = 0; i < logLik1.length; i++){ minLog = Math.min(logLik1[i],logLik2[i]); if(minLog == logLik1[i]){ lik1[i] = 1.0; lik2[i] = Math.exp(logLik2[i] - minLog); }else{ lik1[i] = Math.exp(logLik1[i] - minLog); lik2[i] = 1.0; } }*/ for (int i = 0; i < logLik1.length; i++) { lik1[i] = Math.exp(logLik1[i]); lik2[i] = Math.exp(logLik2[i]); // System.out.println(lik1[i]+" "+lik2[i]); } /*for(int i = 0; i < clusterSites.length;i++){ System.out.println("clusterSites: "+clusterSites[i]); } System.out.println("index 1: "+index1+" index2: "+index2);*/ int cluster1Count = 1; int cluster2Count = 1; int[] sitesInCluster1 = new int[initClusterSites.length]; sitesInCluster1[0] = index1; // Assign members of the existing cluster (except for indice 1 and 2) randomly // to the existing and the new cluster double psi1, psi2, newClusterProb, draw; for (int i = 0; i < clusterSites.length; i++) { psi1 = cluster1Count * lik1[i]; psi2 = cluster2Count * lik2[i]; newClusterProb = psi1 / (psi1 + psi2); draw = Randomizer.nextDouble(); if (draw < newClusterProb) { // System.out.println("in new cluster: "+clusterSites[i]); sitesInCluster1[cluster1Count] = clusterSites[i]; // paramPointers.point(clusterSites[i],newParam); // modelPointers.point(clusterSites[i],newModel); // freqsPointers.point(clusterSites[i],newFreqs); // ratesPointers.point(clusterSites[i],newRates); logqSplit += Math.log(newClusterProb); cluster1Count++; } else { logqSplit += Math.log(1.0 - newClusterProb); cluster2Count++; } } // logqSplit += paramBaseDistr.calcLogP(newParam) logqSplit += // modelBaseDistr.calcLogP(newModel) + freqsBaseDistr.calcLogP(newFreqs) // + ratesBaseDistr.calcLogP(newRates) ; // Perform a split paramList = paramListInput.get(this); modelList = modelListInput.get(this); freqsList = freqsListInput.get(this); ratesList = ratesListInput.get(this); paramPointers = paramPointersInput.get(this); modelPointers = modelPointersInput.get(this); freqsPointers = freqsPointersInput.get(this); ratesPointers = ratesPointersInput.get(this); paramList.splitParameter(clusterIndex, newParam); modelList.splitParameter(clusterIndex, newModel); freqsList.splitParameter(clusterIndex, newFreqs); ratesList.splitParameter(clusterIndex, newRates); // Form a new cluster with index 1 paramPointers = paramPointersInput.get(this); modelPointers = modelPointersInput.get(this); freqsPointers = freqsPointersInput.get(this); ratesPointers = ratesPointersInput.get(this); for (int i = 0; i < cluster1Count; i++) { paramPointers.point(sitesInCluster1[i], newParam); modelPointers.point(sitesInCluster1[i], newModel); freqsPointers.point(sitesInCluster1[i], newFreqs); ratesPointers.point(sitesInCluster1[i], newRates); } return -logqSplit; } catch (Exception e) { throw new RuntimeException(e); } }
public double split(int index1, int index2, int clusterIndex, int[] initClusterSites) { try { double logqSplit = 0.0; // Create a parameter by sampling from the prior QuietRealParameter newParam = getSample(paramBaseDistr, paramList.getUpper(), paramList.getLower()); QuietRealParameter newModel = getSample(modelBaseDistr, modelList.getUpper(), modelList.getLower()); QuietRealParameter newFreqs = getSample(freqsBaseDistr, freqsList.getUpper(), freqsList.getLower()); // Perform a split // paramList.splitParameter(clusterIndex,newParam); // modelList.splitParameter(clusterIndex,newModel); // freqsList.splitParameter(clusterIndex,newFreqs); // Remove the index 1 and index 2 from the cluster int[] clusterSites = new int[initClusterSites.length - 2]; int k = 0; for (int i = 0; i < initClusterSites.length; i++) { if (initClusterSites[i] != index1 && initClusterSites[i] != index2) { clusterSites[k++] = initClusterSites[i]; } } // Form a new cluster with index 1 // paramPointers.point(index1,newParam); // modelPointers.point(index1,newModel); // freqsPointers.point(index1,newFreqs); // Shuffle the cluster_-{index_1,index_2} to obtain a random permutation Randomizer.shuffle(clusterSites); // Create the weight vector of site patterns according to the order of the shuffled index. /*int[] tempWeights = new int[tempLikelihood.m_data.get().getPatternCount()]; int patIndex; for(int i = 0; i < clusterSites.length; i++){ patIndex = tempLikelihood.m_data.get().getPatternIndex(clusterSites[i]); tempWeights[patIndex] = 1; }*/ tempLikelihood.setupPatternWeightsFromSites(clusterSites); // Site log likelihoods in the order of the shuffled sites double[] logLik1 = tempLikelihood.calculateLogP(newParam, newModel, newFreqs, clusterSites); double[] logLik2 = new double[clusterSites.length]; for (int i = 0; i < logLik2.length; i++) { // logLik2[i] = dpTreeLikelihood.getSiteLogLikelihood(clusterIndex,clusterSites[i]); logLik2[i] = getSiteLogLikelihood( paramList.getParameter(clusterIndex).getIDNumber(), clusterIndex, clusterSites[i]); } double[] lik1 = new double[logLik1.length]; double[] lik2 = new double[logLik2.length]; double maxLog; // scale it so it may be more accurate for (int i = 0; i < logLik1.length; i++) { maxLog = Math.max(logLik1[i], logLik2[i]); if (Math.exp(maxLog) < 1e-100) { if (maxLog == logLik1[i]) { lik1[i] = 1.0; lik2[i] = Math.exp(logLik2[i] - maxLog); } else { lik1[i] = Math.exp(logLik1[i] - maxLog); lik2[i] = 1.0; } } else { lik1[i] = Math.exp(logLik1[i]); lik2[i] = Math.exp(logLik2[i]); } } /*boolean ohCrap = false; for(int i = 0; i < logLik1.length; i++){ if(Double.isNaN(logLik1[i])){ return Double.NEGATIVE_INFINITY; //ohCrap = true; //System.out.println("logLik1: "+logLik1); //logLik1[i] = Double.NEGATIVE_INFINITY; } if(Double.isNaN(logLik2[i])){ return Double.NEGATIVE_INFINITY; //ohCrap = true; //System.out.println("logLik1: "+logLik2); //logLik2[i] = Double.NEGATIVE_INFINITY; } lik1[i] = Math.exp(logLik1[i]); lik2[i] = Math.exp(logLik2[i]); //System.out.println(lik1[i]+" "+lik2[i]); } if(ohCrap){ for(int i = 0; i < newParam.getDimension();i++){ System.out.print(newParam.getValue(i)+" "); } System.out.println(); }*/ /*for(int i = 0; i < clusterSites.length;i++){ System.out.println("clusterSites: "+clusterSites[i]); } System.out.println("index 1: "+index1+" index2: "+index2);*/ int cluster1Count = 1; int cluster2Count = 1; // Assign members of the existing cluster (except for indice 1 and 2) randomly // to the existing and the new cluster double psi1, psi2, newClusterProb, draw; int[] newAssignment = new int[clusterSites.length]; for (int i = 0; i < clusterSites.length; i++) { psi1 = cluster1Count * lik1[i]; psi2 = cluster2Count * lik2[i]; newClusterProb = psi1 / (psi1 + psi2); draw = Randomizer.nextDouble(); if (draw < newClusterProb) { // System.out.println("in new cluster: "+clusterSites[i]); // paramPointers.point(clusterSites[i],newParam); // modelPointers.point(clusterSites[i],newModel); // freqsPointers.point(clusterSites[i],newFreqs); newAssignment[cluster1Count - 1] = clusterSites[i]; logqSplit += Math.log(newClusterProb); cluster1Count++; } else { logqSplit += Math.log(1.0 - newClusterProb); cluster2Count++; } } // System.out.println("halfway: "+logqSplit); logqSplit += paramBaseDistr.calcLogP(newParam) + modelBaseDistr.calcLogP(newModel) + freqsBaseDistr.calcLogP(newFreqs); if (-logqSplit > Double.NEGATIVE_INFINITY) { paramList = paramListInput.get(this); modelList = modelListInput.get(this); freqsList = freqsListInput.get(this); paramPointers = paramPointersInput.get(this); modelPointers = modelPointersInput.get(this); freqsPointers = freqsPointersInput.get(this); // Perform a split paramList.splitParameter(clusterIndex, newParam); modelList.splitParameter(clusterIndex, newModel); freqsList.splitParameter(clusterIndex, newFreqs); // Form a new cluster with index 1 paramPointers.point(index1, newParam); modelPointers.point(index1, newModel); freqsPointers.point(index1, newFreqs); for (int i = 0; i < (cluster1Count - 1); i++) { paramPointers.point(newAssignment[i], newParam); modelPointers.point(newAssignment[i], newModel); freqsPointers.point(newAssignment[i], newFreqs); } } return -logqSplit; } catch (Exception e) { // freqsBaseDistr.printDetails(); throw new RuntimeException(e); } }