/** * Test the behaviour of a synapse model. * * @param network A neural network containing two neurons and a single synapse to be tested. The * neurons at index 0 and 1 should be the pre- and post-synaptic neurons respectively, and are * typically configured to produce a fixed firing pattern (though any neuron type and firing * pattern are permitted). * @param simSteps The number of steps to run the simulation for. * @param logSpikesAndStateVariables Whether to record pre- and post-synaptic spikes and any state * variables exposed by the synapse model in the test results. * @param simStepsNoSpikes The number of steps to run the simulation for with no spiking after the * normal spike test. This is useful for testing models which change the efficacy in the * absence of spikes. * @return For a single spike protocol, a TestResults object with type {@link TYPE#STDP} * consisting of series labelled "Time" and "Efficacy" and if logSpikesAndStateVariables == * true then also "Pre-synaptic spikes", "Post-synaptic spikes" and any state variables * exposed by the synapse model. */ public static TestResults singleTest( NeuralNetwork network, long simSteps, boolean logSpikesAndStateVariables, long simStepsNoSpikes) { if (network.getNeurons().getSize() != 2 || network.getSynapses().getSize() != 1) { throw new IllegalArgumentException( "The neural network must contain at least 2 neurons and 1 synapse."); } simStepsNoSpikes = Math.max(0, simStepsNoSpikes); int displayTimeResolution = Math.min(1000, network.getTimeResolution()); int logStepCount = network.getTimeResolution() / displayTimeResolution; int logSize = (int) ((simSteps + simStepsNoSpikes) / logStepCount); double[] timeLog = new double[logSize]; double[] efficacyLog = new double[logSize]; double[][] prePostLogs = null, traceLogs = null; double[] stateVars; if (logSpikesAndStateVariables) { prePostLogs = new double[2][logSize]; traceLogs = new double[network.getSynapses().getStateVariableNames().length][logSize]; } int logIndex = 0; long step = 0; for (; step < simSteps; step++) { double time = network.getTime(); network.step(); if (step % logStepCount == 0) { timeLog[logIndex] = time; efficacyLog[logIndex] = network.getSynapses().getEfficacy(0); // If we're only testing a few repetitions, include some extra // data. if (logSpikesAndStateVariables) { prePostLogs[0][logIndex] = network.getNeurons().getOutput(0); prePostLogs[1][logIndex] = network.getNeurons().getOutput(1); stateVars = network.getSynapses().getStateVariableValues(0); for (int v = 0; v < stateVars.length; v++) { traceLogs[v][logIndex] = stateVars[v]; } } logIndex++; } } if (simStepsNoSpikes > 0) { FixedProtocolNeuronConfiguration config = new FixedProtocolNeuronConfiguration(100, new double[] {}); network.getNeurons().setConfiguration(0, config); network.getNeurons().setComponentConfiguration(0, 0); network.getNeurons().setComponentConfiguration(1, 0); for (; step < simSteps + simStepsNoSpikes; step++) { double time = network.getTime(); network.step(); if (step % logStepCount == 0) { timeLog[logIndex] = time; efficacyLog[logIndex] = network.getSynapses().getEfficacy(0); // If we're only testing a few repetitions, include some extra data. if (logSpikesAndStateVariables) { prePostLogs[0][logIndex] = network.getNeurons().getOutput(0); prePostLogs[1][logIndex] = network.getNeurons().getOutput(1); stateVars = network.getSynapses().getStateVariableValues(0); for (int v = 0; v < stateVars.length; v++) { traceLogs[v][logIndex] = stateVars[v]; } } logIndex++; } } } TestResults results = new TestResults(); results.setProperty("type", TYPE.STDP); results.addResult("Efficacy", efficacyLog); results.addResult("Time", timeLog); if (logSpikesAndStateVariables) { results.addResult("Pre-synaptic spikes", prePostLogs[0]); results.addResult("Post-synaptic spikes", prePostLogs[1]); String[] stateVariableNames = network.getSynapses().getStateVariableNames(); for (int v = 0; v < stateVariableNames.length; v++) { results.addResult(stateVariableNames[v], traceLogs[v]); } } return results; }
/** * Test a synapse on the specified spiking protocol or a series of spiking protocols derived from * initial and final protocols by interpolation over one or two dimensions. * * @param synapse The SynapseCollection containing the synapse to test (the first synapse is * used). * @param timeResolution The time resolution to use in the simulation, see {@link * com.ojcoleman.bain.NeuralNetwork} * @param period The period of the spike pattern in seconds. * @param repetitions The number of times to apply the spike pattern. * @param patterns Array containing spike patterns, in the form [initial, dim 1, dim 2][pre, * post][spike number] = spike time. The [spike number] array contains the times (s) of each * spike, relative to the beginning of the pattern. See {@link * com.ojcoleman.bain.neuron.spiking.FixedProtocolNeuronCollection}. * @param refSpikeIndexes Array specifying indexes of the two spikes to use as timing variation * references for each variation dimension, in the form [dim 1, dim 2][reference spike, * relative spike] = spike index. * @param refSpikePreOrPost Array specifying whether the timing variation reference spikes * specified by refSpikeIndexes belong to the pre- or post-synaptic neurons, in the form [dim * 1, dim 2][base spike, relative spike] = Constants.PRE or Constants.POST. * @param logSpikesAndStateVariables Whether to record pre- and post-synaptic spikes and any state * variables exposed by the synapse model in the test results. * @param progressMonitor If not null, this will be updated with the current progress. * @return For a single spike protocol, a TestResults object with type {@link TYPE#STDP} * consisting of series labelled "Time" and "Efficacy" and if logSpikesAndStateVariables == * true then also "Pre-synaptic spikes", "Post-synaptic spikes" and any state variables * exposed by the synapse model. For a protocol varied over one dimension, a TestResults * object with type {@link TYPE#STDP_1D} consisting of series labelled "Time delta" and * "Efficacy". For a protocol varied over two dimensions, a TestResults object with type * {@link TYPE#STDP_2D} consisting of series labelled "Time delta 1", "Time delta 2" and * "Efficacy". */ public static TestResults testPattern( SynapseCollection<? extends ComponentConfiguration> synapse, int timeResolution, double period, int repetitions, double[][][] patterns, int[][] refSpikeIndexes, int[][] refSpikePreOrPost, boolean logSpikesAndStateVariables, ProgressMonitor progressMonitor) throws IllegalArgumentException { int variationDimsCount = patterns.length - 1; // Number of dimensions over which spike timing patterns vary. if (variationDimsCount > 2) { throw new IllegalArgumentException( "The number of variation dimensions may not exceed 2 (patterns.length must be <= 3)"); } if (progressMonitor != null) { progressMonitor.setMinimum(0); } TestResults results = new TestResults(); FixedProtocolNeuronCollection neurons = new FixedProtocolNeuronCollection(2); FixedProtocolNeuronConfiguration preConfig = new FixedProtocolNeuronConfiguration(period, patterns[0][0]); neurons.addConfiguration(preConfig); neurons.setComponentConfiguration(0, 0); FixedProtocolNeuronConfiguration postConfig = new FixedProtocolNeuronConfiguration(period, patterns[0][1]); neurons.addConfiguration(postConfig); neurons.setComponentConfiguration(1, 1); synapse.setPreNeuron(0, 0); synapse.setPostNeuron(0, 1); NeuralNetwork sim = new NeuralNetwork(timeResolution, neurons, synapse); int simSteps = (int) Math.round(period * repetitions * timeResolution); int displayTimeResolution = Math.min(1000, timeResolution); results.setProperty("simulation time resolution", timeResolution); results.setProperty("display time resolution", displayTimeResolution); // long startTime = System.currentTimeMillis(); // If we're just testing a single spike pattern. // Handle separately as logging is quite // different from testing spike // patterns with gradually altered spike times. if (variationDimsCount == 0) { results = singleTest(sim, simSteps, logSpikesAndStateVariables, 0); } else { // We're testing spike patterns with gradually altered spike times over one or two // dimensions. int[] spikeCounts = {patterns[0][0].length, patterns[0][1].length}; // The initial and final time deltas (s), given base and relative spike times in initial and // final spike patterns, // for each variation dimension. double[] timeDeltaInitial = new double[2], timeDeltaFinal = new double[2]; // The time delta range(s) for each variation dimension. double[] timeDeltaRange = new double[2]; int[] positionsCount = new int[2]; int[][] variationDimForSpike = new int[2][Math.max(spikeCounts[0], spikeCounts[1])]; // [pre, post][spike index] // Set-up parameters for testing spike patterns with gradually altered spike times over one or // two dimensions. for (int d = 0; d < variationDimsCount; d++) { double baseRefSpikeTimeInitial = patterns[0][refSpikePreOrPost[d][0]][refSpikeIndexes[d][0]]; double relativeRefSpikeTimeInitial = patterns[0][refSpikePreOrPost[d][1]][refSpikeIndexes[d][1]]; double baseRefSpikeTimeFinal = patterns[d + 1][refSpikePreOrPost[d][0]][refSpikeIndexes[d][0]]; double relativeRefSpikeTimeFinal = patterns[d + 1][refSpikePreOrPost[d][1]][refSpikeIndexes[d][1]]; timeDeltaInitial[d] = relativeRefSpikeTimeInitial - baseRefSpikeTimeInitial; timeDeltaFinal[d] = relativeRefSpikeTimeFinal - baseRefSpikeTimeFinal; timeDeltaRange[d] = Math.abs(timeDeltaInitial[d] - timeDeltaFinal[d]); // From the initial and final spiking protocols we generate intermediate spiking protocols // by interpolation. // // Each position in between the initial and final protocol adjusts the time differential // between the base and // reference spikes by (1/timeResolution) seconds. positionsCount[d] = (int) Math.round(timeDeltaRange[d] * displayTimeResolution) + 1; // Determine which dimension, if any, a spikes timing varies over (and ensure that a spikes // timing only varies // over at most one dimension). // If the spikes time in variation dimension d is different to the initial spike time. for (int p = 0; p < 2; p++) { for (int si = 0; si < spikeCounts[p]; si++) { // If it also differs in another dimension. if (patterns[0][p][si] != patterns[d + 1][p][si]) { if (variationDimForSpike[p][si] != 0) { throw new IllegalArgumentException( "A spikes timing may vary at most over one variation dimension. " + (p == 0 ? "Pre" : "Post") + "-synaptic spike " + (si + 1) + " varies over two."); } variationDimForSpike[p][si] = d + 1; } } } } double[][] currentSpikeTimings = new double[2][]; // Current pre and post spiking patterns [pre, post][spike index] for (int p = 0; p < 2; p++) { currentSpikeTimings[p] = new double[spikeCounts[p]]; System.arraycopy(patterns[0][p], 0, currentSpikeTimings[p], 0, spikeCounts[p]); } // If we're testing spike patterns with gradually altered spike times over one dimension. if (variationDimsCount == 1) { // Arrays to record results. double[] time = new double[positionsCount[0]]; // The time delta in seconds [time delta index] // The change in synapse efficacy after all repetitions for each pattern [time delta index] double[] efficacyLog = new double[positionsCount[0]]; if (progressMonitor != null) { progressMonitor.setMaximum(positionsCount[0]); } for (int timeDeltaIndex = 0; timeDeltaIndex < positionsCount[0]; timeDeltaIndex++) { if (progressMonitor != null) { progressMonitor.setProgress(timeDeltaIndex); } double position = (double) timeDeltaIndex / (positionsCount[0] - 1); // Position in variation dimension 1 // Generate pre and post spike timing patterns for this position. for (int p = 0; p < 2; p++) { for (int si = 0; si < spikeCounts[p]; si++) { // If this spikes timing varies. int variationDim = variationDimForSpike[p][si]; if (variationDim != 0) { currentSpikeTimings[p][si] = position * patterns[0][p][si] + (1 - position) * patterns[variationDim][p][si]; } } } preConfig.spikeTimings = currentSpikeTimings[0]; postConfig.spikeTimings = currentSpikeTimings[1]; preConfig.fireChangeEvent(); postConfig.fireChangeEvent(); sim.reset(); sim.run(simSteps); time[timeDeltaIndex] = position * timeDeltaInitial[0] + (1 - position) * timeDeltaFinal[0]; efficacyLog[timeDeltaIndex] = synapse.getEfficacy(0); } results.setProperty("type", TYPE.STDP_1D); results.addResult("Efficacy", efficacyLog); results.addResult("Time delta", time); // We're testing spike patterns with gradually altered spike times over two dimensions. } else { // The change in synapse efficacy after all repetitions for each pattern // [time delta for var dim 1, time delta for var dim 2, synapse efficacy][result index] double[][] efficacyLog = new double[3][(positionsCount[0]) * (positionsCount[1])]; if (progressMonitor != null) { progressMonitor.setMaximum(efficacyLog[0].length); } double[] position = new double[2]; // Position in variation dimensions 1 and 2 for (int timeDeltaIndex1 = 0, resultIndex = 0; timeDeltaIndex1 < positionsCount[0]; timeDeltaIndex1++) { position[0] = (double) timeDeltaIndex1 / (positionsCount[0] - 1); for (int timeDeltaIndex2 = 0; timeDeltaIndex2 < positionsCount[1]; timeDeltaIndex2++, resultIndex++) { if (progressMonitor != null) { progressMonitor.setProgress(resultIndex); } position[1] = (double) timeDeltaIndex2 / (positionsCount[1] - 1); // Generate pre and post spike timing patterns for this position. for (int p = 0; p < 2; p++) { for (int si = 0; si < spikeCounts[p]; si++) { // If this spikes timing varies. int variationDim = variationDimForSpike[p][si]; if (variationDim != 0) { currentSpikeTimings[p][si] = (1 - position[variationDim - 1]) * patterns[0][p][si] + position[variationDim - 1] * patterns[variationDim][p][si]; } } } preConfig.spikeTimings = currentSpikeTimings[0]; postConfig.spikeTimings = currentSpikeTimings[1]; preConfig.fireChangeEvent(); postConfig.fireChangeEvent(); sim.reset(); sim.run(simSteps); efficacyLog[0][resultIndex] = (1 - position[0]) * timeDeltaInitial[0] + position[0] * timeDeltaFinal[0]; efficacyLog[1][resultIndex] = (1 - position[1]) * timeDeltaInitial[1] + position[1] * timeDeltaFinal[1]; efficacyLog[2][resultIndex] = synapse.getEfficacy(0); } } results.setProperty("type", TYPE.STDP_2D); results.addResult("Time delta 1", "Time delta 2", "Efficacy", efficacyLog); } } return results; }