/**
 * {@inheritDoc}
 *
 * <p>Reads a {@link BoltzmannMachine} section by section from the stream.
 * Only sections named {@code BOLTZMANN} are consumed: the {@code PARAMS}
 * sub-section is copied into the machine's properties, and the
 * {@code NETWORK} sub-section supplies the weights, current state, neuron
 * count, thresholds, cycle counts and temperature.
 */
@Override
public Object read(final InputStream is) {
    final BoltzmannMachine machine = new BoltzmannMachine();
    final EncogReadHelper reader = new EncogReadHelper(is);

    for (EncogFileSection section = reader.readNextSection();
            section != null;
            section = reader.readNextSection()) {

        // Anything that is not a BOLTZMANN section is ignored.
        if (!section.getSectionName().equals("BOLTZMANN")) {
            continue;
        }

        final String subSection = section.getSubSectionName();

        if (subSection.equals("PARAMS")) {
            final Map<String, String> properties = section.parseParams();
            machine.getProperties().putAll(properties);
        } else if (subSection.equals("NETWORK")) {
            final Map<String, String> values = section.parseParams();

            // Each value is parsed immediately before its setter so the
            // parse/assign sequence is the same as a straight-line read.
            final double[] weights =
                NumberList.fromList(CSVFormat.EG_FORMAT, values.get(PersistConst.WEIGHTS));
            machine.setWeights(weights);

            final double[] state =
                NumberList.fromList(CSVFormat.EG_FORMAT, values.get(PersistConst.OUTPUT));
            machine.setCurrentState(state);

            final int neuronCount =
                EncogFileSection.parseInt(values, PersistConst.NEURON_COUNT);
            machine.setNeuronCount(neuronCount);

            final double[] thresholds =
                NumberList.fromList(CSVFormat.EG_FORMAT, values.get(PersistConst.THRESHOLDS));
            machine.setThreshold(thresholds);

            final int annealCycles =
                EncogFileSection.parseInt(values, BoltzmannMachine.ANNEAL_CYCLES);
            machine.setAnnealCycles(annealCycles);

            final int runCycles =
                EncogFileSection.parseInt(values, BoltzmannMachine.RUN_CYCLES);
            machine.setRunCycles(runCycles);

            final double temperature =
                EncogFileSection.parseDouble(values, PersistConst.TEMPERATURE);
            machine.setTemperature(temperature);
        }
    }

    return machine;
}
/**
 * Create a SOM neighborhood trainer ({@link BasicTrainSOM}).
 *
 * <p>Recognised arguments (all read through a {@link ParamsHolder}):
 * <ul>
 *   <li>learning rate (default 0.7);</li>
 *   <li>neighborhood function name: {@code bubble}, {@code rbf} (default),
 *       {@code rbf1d} or {@code single}, compared case-insensitively;</li>
 *   <li>RBF type: {@code Gaussian}, {@code Multiquadric},
 *       {@code InverseMultiquadric} or {@code MexicanHat}; any other value
 *       falls back to Gaussian;</li>
 *   <li>dimensions, requested with the required flag set when the
 *       neighborhood is {@code rbf};</li>
 *   <li>if an iteration count is present, auto-decay is configured from the
 *       start/end learning rate and start/end radius arguments.</li>
 * </ul>
 *
 * @param method The method to train; must be a {@link SOM}.
 * @param training The training data to use.
 * @param argsStr The arguments to use.
 * @return The newly created trainer.
 * @throws EncogError if {@code method} is not a SOM, or the neighborhood
 *     function name is not recognised.
 */
public MLTrain create(final MLMethod method, final MLDataSet training, final String argsStr) {
    if (!(method instanceof SOM)) {
        // instanceof is false for null, so guard getClass() to report the
        // problem as an EncogError rather than a NullPointerException.
        final String typeName = (method == null) ? "null" : method.getClass().getName();
        throw new EncogError(
            "Neighborhood training cannot be used on a method of type: " + typeName);
    }

    final Map<String, String> args = ArchitectureParse.parseParams(argsStr);
    final ParamsHolder holder = new ParamsHolder(args);

    final double learningRate =
        holder.getDouble(MLTrainFactory.PROPERTY_LEARNING_RATE, false, 0.7);
    final String neighborhoodStr =
        holder.getString(MLTrainFactory.PROPERTY_NEIGHBORHOOD, false, "rbf");
    final String rbfTypeStr =
        holder.getString(MLTrainFactory.PROPERTY_RBF_TYPE, false, "gaussian");

    // Unknown RBF type names deliberately fall back to Gaussian.
    final RBFEnum t;
    if (rbfTypeStr.equalsIgnoreCase("Multiquadric")) {
        t = RBFEnum.Multiquadric;
    } else if (rbfTypeStr.equalsIgnoreCase("InverseMultiquadric")) {
        t = RBFEnum.InverseMultiquadric;
    } else if (rbfTypeStr.equalsIgnoreCase("MexicanHat")) {
        t = RBFEnum.MexicanHat;
    } else {
        t = RBFEnum.Gaussian;
    }

    final NeighborhoodFunction nf;
    if (neighborhoodStr.equalsIgnoreCase("bubble")) {
        nf = new NeighborhoodBubble(1);
    } else if (neighborhoodStr.equalsIgnoreCase("rbf")) {
        final String str = holder.getString(MLTrainFactory.PROPERTY_DIMENSIONS, true, null);
        final int[] size = NumberList.fromListInt(CSVFormat.EG_FORMAT, str);
        nf = new NeighborhoodRBF(size, t);
    } else if (neighborhoodStr.equalsIgnoreCase("rbf1d")) {
        nf = new NeighborhoodRBF1D(t);
    } else if (neighborhoodStr.equalsIgnoreCase("single")) {
        nf = new NeighborhoodSingle();
    } else {
        // Previously an unrecognised name left the neighborhood function
        // null and handed it to the trainer; fail fast with a clear message.
        throw new EncogError("Unknown neighborhood function: " + neighborhoodStr);
    }

    final BasicTrainSOM result = new BasicTrainSOM((SOM) method, learningRate, training, nf);

    if (args.containsKey(MLTrainFactory.PROPERTY_ITERATIONS)) {
        final int plannedIterations =
            holder.getInt(MLTrainFactory.PROPERTY_ITERATIONS, false, 1000);
        // NOTE(review): start and end learning rate share the default 0.05,
        // so the rate only decays when at least one is given explicitly.
        final double startRate =
            holder.getDouble(MLTrainFactory.PROPERTY_START_LEARNING_RATE, false, 0.05);
        final double endRate =
            holder.getDouble(MLTrainFactory.PROPERTY_END_LEARNING_RATE, false, 0.05);
        final double startRadius =
            holder.getDouble(MLTrainFactory.PROPERTY_START_RADIUS, false, 10);
        final double endRadius =
            holder.getDouble(MLTrainFactory.PROPERTY_END_RADIUS, false, 1);
        result.setAutoDecay(plannedIterations, startRate, endRate, startRadius, endRadius);
    }

    return result;
}
/** * Normalize the input file. Write to the specified file. * * @param file The file to write to. */ public void normalize(final File file) { if (this.analyst == null) { throw new EncogError("Can't normalize yet, file has not been analyzed."); } ReadCSV csv = null; PrintWriter tw = null; try { csv = new ReadCSV(getInputFilename().toString(), isExpectInputHeaders(), getFormat()); tw = new PrintWriter(new FileWriter(file)); // write headers, if needed if (isProduceOutputHeaders()) { writeHeaders(tw); } resetStatus(); final int outputLength = this.analyst.determineTotalColumns(); // write file contents while (csv.next() && !shouldStop()) { updateStatus(false); double[] output = AnalystNormalizeCSV.extractFields( this.analyst, this.analystHeaders, csv, outputLength, false); if (this.series.getTotalDepth() > 1) { output = this.series.process(output); } if (output != null) { final StringBuilder line = new StringBuilder(); NumberList.toList(getFormat(), line, output); tw.println(line); } } } catch (final IOException e) { throw new QuantError(e); } finally { reportDone(false); if (csv != null) { try { csv.close(); } catch (final Exception ex) { EncogLogging.log(ex); } } if (tw != null) { try { tw.close(); } catch (final Exception ex) { EncogLogging.log(ex); } } } }
/**
 * {@inheritDoc}
 *
 * <p>The text is either a bare name (for example {@code name}) or a name
 * followed by a bracketed parameter list (for example {@code name[1,2]}).
 * The name is lower-cased before lookup via {@code allocateAF}; the
 * parameter list is parsed with {@code CSVFormat.EG_FORMAT}.
 *
 * @param fn The textual activation function specification.
 * @return The configured activation function, or {@code null} when
 *     {@code allocateAF} does not recognise the name.
 * @throws EncogError if an opening {@code [} has no closing {@code ]} after
 *     it, or the number of parameters supplied differs from the number the
 *     activation function declares.
 */
@Override
public ActivationFunction createActivationFunction(String fn) {
    final String name;
    final double[] params;

    final int open = fn.indexOf('[');
    if (open != -1) {
        // Locale.ROOT keeps the lookup key stable regardless of the default
        // locale (e.g. Turkish dotless-i lower-casing).
        name = fn.substring(0, open).toLowerCase(java.util.Locale.ROOT);

        // Search for the closing bracket only after the opening one, so a
        // stray ']' earlier in the text cannot yield an inverted range.
        final int close = fn.indexOf(']', open);
        if (close == -1) {
            throw new EncogError("Unbounded [ while parsing activation function.");
        }
        final String a = fn.substring(open + 1, close);
        params = NumberList.fromList(CSVFormat.EG_FORMAT, a);
    } else {
        name = fn.toLowerCase(java.util.Locale.ROOT);
        params = new double[0];
    }

    final ActivationFunction af = allocateAF(name);
    if (af == null) {
        return null;
    }

    final int expected = af.getParamNames().length;
    if (expected != params.length) {
        throw new EncogError(
            name + " expected " + expected + ", but " + params.length + " were provided.");
    }

    for (int i = 0; i < expected; i++) {
        af.setParam(i, params[i]);
    }

    return af;
}