Ejemplo n.º 1
0
  /**
   * Test the behaviour of a synapse model.
   *
   * <p>The given network is stepped in place, so its time and neuron/synapse state are advanced by
   * this method. If {@code simStepsNoSpikes > 0}, neuron configuration 0 of the network is also
   * replaced with a {@link FixedProtocolNeuronConfiguration} with no spike times, and both neurons
   * are switched to that configuration before the no-spike phase runs.
   *
   * <p>A sample is logged on every step whose index is a multiple of {@code
   * timeResolution / min(1000, timeResolution)}, starting with step 0. The logged time is the
   * network time before that step; all other values are read after the step.
   *
   * @param network A neural network containing two neurons and a single synapse to be tested. The
   *     neurons at index 0 and 1 should be the pre- and post-synaptic neurons respectively, and are
   *     typically configured to produce a fixed firing pattern (though any neuron type and firing
   *     pattern are permitted).
   * @param simSteps The number of steps to run the simulation for.
   * @param logSpikesAndStateVariables Whether to record pre- and post-synaptic spikes and any state
   *     variables exposed by the synapse model in the test results.
   * @param simStepsNoSpikes The number of steps to run the simulation for with no spiking after the
   *     normal spike test. This is useful for testing models which change the efficacy in the
   *     absence of spikes. Negative values are treated as 0.
   * @return For a single spike protocol, a TestResults object with type {@link TYPE#STDP}
   *     consisting of series labelled "Time" and "Efficacy" and if logSpikesAndStateVariables ==
   *     true then also "Pre-synaptic spikes", "Post-synaptic spikes" and any state variables
   *     exposed by the synapse model.
   * @throws IllegalArgumentException if the network does not contain exactly two neurons and one
   *     synapse.
   */
  public static TestResults singleTest(
      NeuralNetwork network,
      long simSteps,
      boolean logSpikesAndStateVariables,
      long simStepsNoSpikes) {
    if (network.getNeurons().getSize() != 2 || network.getSynapses().getSize() != 1) {
      throw new IllegalArgumentException(
          "The neural network must contain exactly 2 neurons and 1 synapse.");
    }

    simStepsNoSpikes = Math.max(0, simStepsNoSpikes);

    int displayTimeResolution = Math.min(1000, network.getTimeResolution());
    int logStepCount = network.getTimeResolution() / displayTimeResolution;
    long totalSteps = simSteps + simStepsNoSpikes;
    // A sample is taken on every step where step % logStepCount == 0, i.e. on
    // ceil(totalSteps / logStepCount) steps. Rounding down would overflow the log arrays whenever
    // totalSteps is not a multiple of logStepCount.
    int logSize = (int) Math.max(0, (totalSteps + logStepCount - 1) / logStepCount);

    double[] timeLog = new double[logSize];
    double[] efficacyLog = new double[logSize];
    // Only allocated (non-null) when logSpikesAndStateVariables is true.
    double[][] prePostLogs = null;
    double[][] traceLogs = null;
    if (logSpikesAndStateVariables) {
      prePostLogs = new double[2][logSize];
      traceLogs = new double[network.getSynapses().getStateVariableNames().length][logSize];
    }

    int logIndex = 0;
    long step = 0;
    for (; step < simSteps; step++) {
      double time = network.getTime();
      network.step();
      if (step % logStepCount == 0) {
        logSingleTestSample(
            network, logIndex++, time, timeLog, efficacyLog, prePostLogs, traceLogs);
      }
    }

    if (simStepsNoSpikes > 0) {
      // Replace configuration 0 with a fixed protocol that has no spike times and switch both
      // neurons to it, which is how the "no spiking" phase is set up.
      // NOTE(review): this assumes the network's neuron collection accepts a
      // FixedProtocolNeuronConfiguration; confirm for non-fixed-protocol neuron types.
      FixedProtocolNeuronConfiguration config =
          new FixedProtocolNeuronConfiguration(100, new double[] {});
      network.getNeurons().setConfiguration(0, config);
      network.getNeurons().setComponentConfiguration(0, 0);
      network.getNeurons().setComponentConfiguration(1, 0);

      for (; step < totalSteps; step++) {
        double time = network.getTime();
        network.step();
        if (step % logStepCount == 0) {
          logSingleTestSample(
              network, logIndex++, time, timeLog, efficacyLog, prePostLogs, traceLogs);
        }
      }
    }

    TestResults results = new TestResults();
    results.setProperty("type", TYPE.STDP);
    results.addResult("Efficacy", efficacyLog);
    results.addResult("Time", timeLog);
    if (logSpikesAndStateVariables) {
      results.addResult("Pre-synaptic spikes", prePostLogs[0]);
      results.addResult("Post-synaptic spikes", prePostLogs[1]);
      String[] stateVariableNames = network.getSynapses().getStateVariableNames();
      for (int v = 0; v < stateVariableNames.length; v++) {
        results.addResult(stateVariableNames[v], traceLogs[v]);
      }
    }

    return results;
  }

  /**
   * Records one sample of the {@link #singleTest} logs at {@code index}.
   *
   * <p>The sample consists of the given time and the efficacy of synapse 0. When {@code
   * prePostLogs} is non-null, it also includes the outputs of neurons 0 and 1 and the state
   * variable values of synapse 0 (written to {@code traceLogs}).
   */
  private static void logSingleTestSample(
      NeuralNetwork network,
      int index,
      double time,
      double[] timeLog,
      double[] efficacyLog,
      double[][] prePostLogs,
      double[][] traceLogs) {
    timeLog[index] = time;
    efficacyLog[index] = network.getSynapses().getEfficacy(0);
    if (prePostLogs != null) {
      prePostLogs[0][index] = network.getNeurons().getOutput(0);
      prePostLogs[1][index] = network.getNeurons().getOutput(1);
      double[] stateVars = network.getSynapses().getStateVariableValues(0);
      for (int v = 0; v < stateVars.length; v++) {
        traceLogs[v][index] = stateVars[v];
      }
    }
  }
Ejemplo n.º 2
0
  /**
   * Test a synapse on the specified spiking protocol or a series of spiking protocols derived from
   * initial and final protocols by interpolation over one or two dimensions.
   *
   * <p>The given synapse collection is wired so that synapse 0 connects neuron 0 (pre) to neuron 1
   * (post) of a newly created two-neuron fixed-protocol collection. The returned results always
   * carry the properties "simulation time resolution" and "display time resolution" (the latter
   * is {@code min(1000, timeResolution)}).
   *
   * @param synapse The SynapseCollection containing the synapse to test (the first synapse is
   *     used).
   * @param timeResolution The time resolution to use in the simulation, see {@link
   *     com.ojcoleman.bain.NeuralNetwork}
   * @param period The period of the spike pattern in seconds.
   * @param repetitions The number of times to apply the spike pattern.
   * @param patterns Array containing spike patterns, in the form [initial, dim 1, dim 2][pre,
   *     post][spike number] = spike time. The [spike number] array contains the times (s) of each
   *     spike, relative to the beginning of the pattern. See {@link
   *     com.ojcoleman.bain.neuron.spiking.FixedProtocolNeuronCollection}.
   * @param refSpikeIndexes Array specifying indexes of the two spikes to use as timing variation
   *     references for each variation dimension, in the form [dim 1, dim 2][reference spike,
   *     relative spike] = spike index.
   * @param refSpikePreOrPost Array specifying whether the timing variation reference spikes
   *     specified by refSpikeIndexes belong to the pre- or post-synaptic neurons, in the form [dim
   *     1, dim 2][base spike, relative spike] = Constants.PRE or Constants.POST.
   * @param logSpikesAndStateVariables Whether to record pre- and post-synaptic spikes and any state
   *     variables exposed by the synapse model in the test results.
   * @param progressMonitor If not null, this will be updated with the current progress, and set to
   *     its maximum when a one- or two-dimensional variation test completes.
   * @return For a single spike protocol, a TestResults object with type {@link TYPE#STDP}
   *     consisting of series labelled "Time" and "Efficacy" and if logSpikesAndStateVariables ==
   *     true then also "Pre-synaptic spikes", "Post-synaptic spikes" and any state variables
   *     exposed by the synapse model. For a protocol varied over one dimension, a TestResults
   *     object with type {@link TYPE#STDP_1D} consisting of series labelled "Time delta" and
   *     "Efficacy". For a protocol varied over two dimensions, a TestResults object with type
   *     {@link TYPE#STDP_2D} consisting of series labelled "Time delta 1", "Time delta 2" and
   *     "Efficacy".
   * @throws IllegalArgumentException if {@code patterns} is empty or has more than 3 entries, or
   *     if a spike's timing varies over more than one variation dimension.
   */
  public static TestResults testPattern(
      SynapseCollection<? extends ComponentConfiguration> synapse,
      int timeResolution,
      double period,
      int repetitions,
      double[][][] patterns,
      int[][] refSpikeIndexes,
      int[][] refSpikePreOrPost,
      boolean logSpikesAndStateVariables,
      ProgressMonitor progressMonitor)
      throws IllegalArgumentException {
    if (patterns.length == 0) {
      throw new IllegalArgumentException(
          "At least the initial spike pattern must be given (patterns.length must be >= 1)");
    }
    // Number of dimensions over which spike timing patterns vary.
    int variationDimsCount = patterns.length - 1;
    if (variationDimsCount > 2) {
      throw new IllegalArgumentException(
          "The number of variation dimensions may not exceed 2 (patterns.length must be <= 3)");
    }

    if (progressMonitor != null) {
      progressMonitor.setMinimum(0);
    }

    FixedProtocolNeuronCollection neurons = new FixedProtocolNeuronCollection(2);
    FixedProtocolNeuronConfiguration preConfig =
        new FixedProtocolNeuronConfiguration(period, patterns[0][0]);
    neurons.addConfiguration(preConfig);
    neurons.setComponentConfiguration(0, 0);
    FixedProtocolNeuronConfiguration postConfig =
        new FixedProtocolNeuronConfiguration(period, patterns[0][1]);
    neurons.addConfiguration(postConfig);
    neurons.setComponentConfiguration(1, 1);

    synapse.setPreNeuron(0, 0);
    synapse.setPostNeuron(0, 1);

    NeuralNetwork sim = new NeuralNetwork(timeResolution, neurons, synapse);

    int simSteps = (int) Math.round(period * repetitions * timeResolution);
    int displayTimeResolution = Math.min(1000, timeResolution);

    TestResults results;
    if (variationDimsCount == 0) {
      // A single spike pattern is handled separately: it logs time series, whereas the variation
      // tests log one final efficacy per interpolated pattern.
      results = singleTest(sim, simSteps, logSpikesAndStateVariables, 0);
    } else {
      results = new TestResults();

      int[] spikeCounts = {patterns[0][0].length, patterns[0][1].length};
      // Initial and final time deltas (s) between the base and relative reference spikes, for each
      // variation dimension.
      double[] timeDeltaInitial = new double[2];
      double[] timeDeltaFinal = new double[2];
      // Number of interpolated patterns (positions) tested along each variation dimension.
      int[] positionsCount = new int[2];
      // Variation dimension (1 or 2) over which each spike's timing varies, or 0 if it does not
      // vary. [pre, post][spike index]
      int[][] variationDimForSpike = new int[2][Math.max(spikeCounts[0], spikeCounts[1])];

      for (int d = 0; d < variationDimsCount; d++) {
        double baseRefSpikeTimeInitial =
            patterns[0][refSpikePreOrPost[d][0]][refSpikeIndexes[d][0]];
        double relativeRefSpikeTimeInitial =
            patterns[0][refSpikePreOrPost[d][1]][refSpikeIndexes[d][1]];
        double baseRefSpikeTimeFinal =
            patterns[d + 1][refSpikePreOrPost[d][0]][refSpikeIndexes[d][0]];
        double relativeRefSpikeTimeFinal =
            patterns[d + 1][refSpikePreOrPost[d][1]][refSpikeIndexes[d][1]];

        timeDeltaInitial[d] = relativeRefSpikeTimeInitial - baseRefSpikeTimeInitial;
        timeDeltaFinal[d] = relativeRefSpikeTimeFinal - baseRefSpikeTimeFinal;
        double timeDeltaRange = Math.abs(timeDeltaInitial[d] - timeDeltaFinal[d]);

        // Intermediate patterns are generated by interpolating between the initial and final
        // patterns; consecutive positions differ in time delta by about 1/displayTimeResolution s.
        positionsCount[d] = (int) Math.round(timeDeltaRange * displayTimeResolution) + 1;

        // Record which dimension each spike's timing varies over, rejecting spikes whose timing
        // differs from the initial pattern in more than one dimension.
        for (int p = 0; p < 2; p++) {
          for (int si = 0; si < spikeCounts[p]; si++) {
            if (patterns[0][p][si] != patterns[d + 1][p][si]) {
              if (variationDimForSpike[p][si] != 0) {
                throw new IllegalArgumentException(
                    "A spikes timing may vary at most over one variation dimension. "
                        + (p == 0 ? "Pre" : "Post")
                        + "-synaptic spike "
                        + (si + 1)
                        + " varies over two.");
              }
              variationDimForSpike[p][si] = d + 1;
            }
          }
        }
      }

      // Current pre and post spiking patterns [pre, post][spike index], starting from the initial
      // pattern; only the varying spikes are overwritten below.
      double[][] currentSpikeTimings = new double[2][];
      for (int p = 0; p < 2; p++) {
        currentSpikeTimings[p] = patterns[0][p].clone();
      }

      if (variationDimsCount == 1) {
        // The time delta in seconds [time delta index].
        double[] timeDeltas = new double[positionsCount[0]];
        // The synapse efficacy after all repetitions for each pattern [time delta index].
        double[] efficacyLog = new double[positionsCount[0]];

        if (progressMonitor != null) {
          progressMonitor.setMaximum(positionsCount[0]);
        }

        for (int timeDeltaIndex = 0; timeDeltaIndex < positionsCount[0]; timeDeltaIndex++) {
          if (progressMonitor != null) {
            progressMonitor.setProgress(timeDeltaIndex);
          }

          double position = stdpVariationPosition(timeDeltaIndex, positionsCount[0]);

          // NOTE: unlike the 2D case, position 0 here selects the final pattern and position 1 the
          // initial one. The time delta below is interpolated the same way, so every
          // (time delta, efficacy) pair is consistent; the series simply runs final -> initial.
          for (int p = 0; p < 2; p++) {
            for (int si = 0; si < spikeCounts[p]; si++) {
              int variationDim = variationDimForSpike[p][si];
              if (variationDim != 0) {
                currentSpikeTimings[p][si] =
                    position * patterns[0][p][si] + (1 - position) * patterns[variationDim][p][si];
              }
            }
          }

          runStdpTrial(sim, simSteps, preConfig, postConfig, currentSpikeTimings);

          timeDeltas[timeDeltaIndex] =
              position * timeDeltaInitial[0] + (1 - position) * timeDeltaFinal[0];
          efficacyLog[timeDeltaIndex] = synapse.getEfficacy(0);
        }

        if (progressMonitor != null) {
          // Report completion (e.g. javax.swing.ProgressMonitor only closes once progress >= max).
          progressMonitor.setProgress(positionsCount[0]);
        }

        results.setProperty("type", TYPE.STDP_1D);
        results.addResult("Efficacy", efficacyLog);
        results.addResult("Time delta", timeDeltas);
      } else {
        // [time delta for var dim 1, time delta for var dim 2, synapse efficacy][result index]
        double[][] efficacyLog = new double[3][positionsCount[0] * positionsCount[1]];

        if (progressMonitor != null) {
          progressMonitor.setMaximum(efficacyLog[0].length);
        }

        // Position in variation dimensions 1 and 2.
        double[] position = new double[2];
        int resultIndex = 0;
        for (int timeDeltaIndex1 = 0; timeDeltaIndex1 < positionsCount[0]; timeDeltaIndex1++) {
          position[0] = stdpVariationPosition(timeDeltaIndex1, positionsCount[0]);

          for (int timeDeltaIndex2 = 0;
              timeDeltaIndex2 < positionsCount[1];
              timeDeltaIndex2++, resultIndex++) {
            if (progressMonitor != null) {
              progressMonitor.setProgress(resultIndex);
            }

            position[1] = stdpVariationPosition(timeDeltaIndex2, positionsCount[1]);

            for (int p = 0; p < 2; p++) {
              for (int si = 0; si < spikeCounts[p]; si++) {
                int variationDim = variationDimForSpike[p][si];
                if (variationDim != 0) {
                  currentSpikeTimings[p][si] =
                      (1 - position[variationDim - 1]) * patterns[0][p][si]
                          + position[variationDim - 1] * patterns[variationDim][p][si];
                }
              }
            }

            runStdpTrial(sim, simSteps, preConfig, postConfig, currentSpikeTimings);

            efficacyLog[0][resultIndex] =
                (1 - position[0]) * timeDeltaInitial[0] + position[0] * timeDeltaFinal[0];
            efficacyLog[1][resultIndex] =
                (1 - position[1]) * timeDeltaInitial[1] + position[1] * timeDeltaFinal[1];
            efficacyLog[2][resultIndex] = synapse.getEfficacy(0);
          }
        }

        if (progressMonitor != null) {
          // Report completion (e.g. javax.swing.ProgressMonitor only closes once progress >= max).
          progressMonitor.setProgress(efficacyLog[0].length);
        }

        results.setProperty("type", TYPE.STDP_2D);
        results.addResult("Time delta 1", "Time delta 2", "Efficacy", efficacyLog);
      }
    }

    // Set here, after the branch, so that the results returned by singleTest also carry them.
    results.setProperty("simulation time resolution", timeResolution);
    results.setProperty("display time resolution", displayTimeResolution);

    return results;
  }

  /**
   * Returns the normalised position of {@code index} among {@code count} evenly spaced positions
   * in [0, 1].
   *
   * <p>Returns 0 when there is at most one position, which avoids the NaN that 0/0 would give.
   */
  private static double stdpVariationPosition(int index, int count) {
    return count > 1 ? (double) index / (count - 1) : 0;
  }

  /**
   * Installs the given [pre, post] spike timings on the two neuron configurations and fires a
   * change event on each. It then resets the network and runs it for {@code simSteps} steps.
   */
  private static void runStdpTrial(
      NeuralNetwork sim,
      int simSteps,
      FixedProtocolNeuronConfiguration preConfig,
      FixedProtocolNeuronConfiguration postConfig,
      double[][] spikeTimings) {
    preConfig.spikeTimings = spikeTimings[0];
    postConfig.spikeTimings = spikeTimings[1];
    preConfig.fireChangeEvent();
    postConfig.fireChangeEvent();

    sim.reset();
    sim.run(simSteps);
  }