Introduction

How can we train spiking neural networks to achieve brain-like performance in machine learning tasks? The resounding success and pervasive use of the backpropagation algorithm in deep learning suggest an analogous approach. This algorithm computes the gradient of a loss function, which measures the network’s performance in a given task, with respect to the neural network parameters. The parameters of the network are then iteratively updated along the locally optimal direction given by the negative gradient.

Spiking neural networks have been referred to as the third generation of neural networks1, superseding artificial neural networks as commonly used in deep learning, and hold the promise of efficient and robust processing of event-based spatio-temporal data as found in biological systems. However, spiking models are not widely used in machine learning applications. At the same time, the development of spiking neuromorphic hardware receives increasing attention2 and learning in spiking neural networks is an active research subject, with a wide variety of proposed algorithms. A notorious issue in spiking neurons is the hard spiking threshold, which does not permit a straightforward application of differential calculus to compute gradients. Although exact gradients have been derived for special cases, this issue is commonly side-stepped by using smoothed or stochastic neuron models or by replacing the hard threshold function with a surrogate function, leading to the computation of surrogate gradients3.

In contrast, this work provides an algorithm, EventProp, to compute the exact gradient for an arbitrary loss function defined using the state variables (spike times and membrane potentials) of a general recurrent spiking neural network composed of leaky integrate-and-fire neurons with hard thresholds. Since feed-forward architectures correspond to recurrent neural networks with strictly block-triangular weight matrices and convolutions can be represented as sparse linear transformations, deep feed-forward networks and convolutional networks are included as special cases.

The leaky integrate-and-fire neuron model describes a hybrid dynamical system that combines continuous dynamics between spikes with discontinuous state variable transitions at spike times. The computation of partial derivatives for hybrid dynamical systems is an established topic in optimal control theory4,5. In hybrid systems, the time-dependent partial derivative \(\frac{\partial x}{\partial p}(t)\) of a state variable x with respect to a parameter p generally experiences jumps at the points of discontinuity (see Fig. 1A,B). The relation between the partial derivatives before and after a given discontinuity was first studied in the 1960s6,7. A more general theoretical framework was developed thirty years later8, providing existence and uniqueness theorems for the partial derivative trajectories \(\frac{\partial x}{\partial p}(t)\) of hybrid systems.

Figure 1

We derive the precise analogue to backpropagation for spiking neural networks by applying the adjoint method together with the jump conditions for partial derivatives at state discontinuities, yielding exact gradients with respect to loss functions based on membrane potentials or spike times. (A, B) Dynamical systems with parameter-dependent discontinuous state transitions typically have discontinuous partial derivatives of state variables with respect to system parameters4, as is the case for the two examples shown here. Both examples model dynamics occurring on short timescales, namely inelastic reflection and the neuronal spike mechanism, using an instantaneous state transition. We denote quantities evaluated before and after a given transition by − and \(+\). In A, a bouncing ball starts at height \(y_0>0\) and is described by \(\ddot{y}=-g\) with gravitational acceleration g. It is inelastically reflected as \(\dot{y}^+=-0.8\dot{y}^-\) as soon as \(y^-=0\) holds, causing the partial derivative with respect to \(y_0\) to jump as \(\frac{\partial y^+}{\partial y_0}=-0.8\frac{\partial y^-}{\partial y_0}\) (see first methods subsection). In B, a leaky integrate-and-fire neuron described by the system given in Table 1 with initial conditions \(I(0)=w\), \(V(0)=0\) resets its membrane potential as \(V^+=0\) when \(V^-=\vartheta \) holds, causing the partial derivative to jump as \(\frac{\partial V^+}{\partial w}=\left( \frac{\vartheta }{\tau _\text {mem}\dot{V}^-}+1\right) \frac{\partial V^-}{\partial w}\) (see methods for the full derivation). (C) Applying the adjoint method with partial derivative jumps to a network of leaky integrate-and-fire neurons (Table 1) yields the adjoint system (Table 2) that backpropagates errors in time. EventProp is an algorithm (Algorithm 1) returning the gradient of a loss function with respect to synaptic weights by computing this adjoint system. 
The forward pass computes the state variables V(t), I(t) and stores spike times \(t^{\text {post}}_{}\) and each firing neuron’s synaptic current. EventProp then performs the backward pass by computing the adjoint system backwards in time using event-based error backpropagation and gradient accumulation: each time a spike was transferred across a given synaptic weight in the forward pass, EventProp backpropagates the error signal represented by the adjoint variables \(\lambda _V(t^{\text {post}}_{})\), \(\lambda _I(t^{\text {post}}_{})\) of the post-synaptic (target) neuron and updates the corresponding component of the gradient by accumulating \(\lambda _I(t^{\text {post}}_{})\), finally yielding sums as given in the figure

Discontinuous state transitions in hybrid systems occur when a transition condition is fulfilled (e.g., a bouncing ball hits the floor or a neuron reaches its spiking threshold). The existence of well-defined partial derivative jumps at the state transition times depends on the local applicability of the implicit function theorem to the transition condition, requiring that the event time depends on the parameters in a differentiable fashion. In the case considered here, a spiking neural network composed of leaky integrate-and-fire neurons that is parameterized by synaptic weights, this is fulfilled up to the null set in weight space that contains the locally defined hypersurfaces where spikes are added or removed. At these critical points, the derivative of the time of the (dis-)appearing spike with respect to a given active synaptic weight diverges. This implies that both the spike times and an integral of a smooth loss function over the membrane potential are differentiable almost everywhere, up to the null set of critical points in weight space.
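The jump quoted for the bouncing ball in Fig. 1A can be checked numerically. The following is a minimal sketch, assuming the closed-form trajectory up to the second impact and an initial velocity of zero; the function names, the value of g and the step sizes are illustrative choices and not part of the original work.

```python
import math

G = 9.81           # gravitational acceleration (m/s^2), illustrative value
RESTITUTION = 0.8  # velocity is reflected as v+ = -0.8 * v-

def height(t, y0):
    """Closed-form height of the ball, valid until the second impact."""
    t_hit = math.sqrt(2.0 * y0 / G)   # the impact time depends on y0
    if t < t_hit:
        return y0 - 0.5 * G * t * t
    v_plus = RESTITUTION * G * t_hit  # upward velocity right after the impact
    s = t - t_hit
    return v_plus * s - 0.5 * G * s * s

def dheight_dy0(t, y0, eps=1e-7):
    """Central-difference estimate of the partial derivative dy/dy0 at fixed t."""
    return (height(t, y0 + eps) - height(t, y0 - eps)) / (2.0 * eps)

y0 = 1.0
t_hit = math.sqrt(2.0 * y0 / G)
dt = 1e-4  # offset chosen so both perturbed trajectories lie on the same side of the impact
before = dheight_dy0(t_hit - dt, y0)
after = dheight_dy0(t_hit + dt, y0)
print(before, after, after / before)
```

Up to terms of order dt, the printed ratio should be close to the factor -0.8 stated in the caption: the partial derivative is 1 just before the impact and approximately -0.8 just after it.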

Having established the jumps of partial derivatives in the leaky integrate-and-fire neuron model, the relevant question is how to compute the gradient of a loss function for spiking neural networks, preferably with the computational efficiency afforded by the backpropagation algorithm and retaining any potential advantages of event-based communication. Backpropagation in discrete-time artificial neural networks can be derived as a special case of the adjoint method9, with the adjoint variables (Lagrange multipliers) \(\lambda _t\) at each time step t corresponding to the intermediate variables computed in the backpropagation algorithm. Applying the adjoint method to continuous-time dynamical systems yields time-dependent adjoint variables \(\lambda (t)\) (see methods section) and their computation in reverse time is analogous to the backpropagation of errors in discrete-time artificial neural networks. The adjoint method can be applied to hybrid systems by using the proper partial derivative jumps that generally cause jumps in the adjoint variables10.

We combine the partial derivative jumps of the leaky integrate-and-fire neuron with the adjoint method in order to derive the EventProp algorithm (Algorithm 1) that is the analogue to backpropagation for spiking neural networks (Fig. 1C). Since EventProp backpropagates errors at spike times, the algorithm computes gradients using an event-based communication scheme and is amenable to neuromorphic implementation. By requiring the storage of state variables only at spike times, it provides favorable memory requirements compared to approaches that require the full forward state trajectory to be retained for the backward pass. For example, surrogate gradient approaches operating on a discrete time grid require storing state variables at every time step for the backward pass. More generally, the fact that backpropagation in discrete-time artificial neural networks requires storing activations at every time step causes a memory bottleneck and is a major concern in training very deep architectures11,12,13.

EventProp does not prescribe a specific numerical scheme to compute state variables and spike times but since the backward pass corresponds to the computation of a spiking network with pre-determined spike times, the computational complexity of the backward pass generally corresponds to that of the forward pass. While surrogate gradient approaches on a discrete time grid typically require the calculation of dense matrix-vector products at every time step in the backward pass (all neurons backpropagate error signals at every time step), EventProp only requires computing vector-vector products at spike events (only the firing neuron receives backpropagated errors at a given spike time). In this way, EventProp leverages the sparseness of spike-based communication for both the forward and backward pass.

We demonstrate the training of spiking neural networks with a single hidden layer using EventProp and the Yin-Yang and MNIST datasets, resulting in competitive classification performance.

Previous work

For a comprehensive survey of gradient-based approaches to learning in spiking neural networks, we refer the reader to review articles which discuss learning in deep spiking networks2,14,15, discuss learning along with the history and future of neuromorphic computing2 or focus on the surrogate gradient approach3. Surrogate gradients use smooth activation functions for the purposes of backpropagation and have been used to train spiking networks in a variety of settings16,17,18,19. This approach is typically derived by considering the Euler discretization of a spiking neural network where the Heaviside step function is used to couple neurons across discrete time steps. The non-differentiable Heaviside step function is then replaced by a smooth function in the backward pass.
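As a rough illustration of the surrogate gradient idea described above, the following sketch pairs a Heaviside step in the forward pass with a smooth stand-in for its derivative in the backward pass. The fast-sigmoid derivative used here is one common choice and is an assumption of this sketch; the function names and the steepness parameter beta are hypothetical.

```python
def heaviside(v, threshold=1.0):
    """Forward pass: emit a spike (1.0) if the membrane potential reaches the threshold."""
    return 1.0 if v >= threshold else 0.0

def surrogate_derivative(v, threshold=1.0, beta=10.0):
    """Backward pass: smooth replacement for the derivative of the step function.

    Uses the derivative of a fast sigmoid, 1 / (1 + beta * |v - threshold|)^2,
    which is an assumed example and not a choice prescribed by the text.
    """
    return 1.0 / (1.0 + beta * abs(v - threshold)) ** 2

# The surrogate is largest at the threshold and decays with distance from it.
print(heaviside(1.2), surrogate_derivative(1.0), surrogate_derivative(0.5))
```

In contrast to this replacement, the exact gradient discussed in this work does not modify the threshold function.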

Apart from surrogate gradients, several publications provide exact gradients for first-spike-time based loss functions and leaky integrate-and-fire neurons: a seminal article20 provides the gradient for at most one spike per layer and this result was subsequently generalized to an arbitrary number of spikes as well as recurrent connectivity21,22. While these publications provide recursive relations for the gradient that can be implicitly computed using backpropagation, we explicitly provide the dynamical system that implements backpropagation through time and show that it represents an adjoint spiking network which transmits errors at spike times, allowing for an event-based computation of the gradient. In addition, we also consider voltage-dependent loss functions and our methodology can be applied to neuron models without analytic expressions for the post-synaptic potential kernels.

The applicability of methods from optimal control theory (i.e., partial derivative jumps and the adjoint method) to compute exact gradients in hard-threshold spiking neural networks was recognized in a series of publications23,24,25. In contrast to this work, these articles consider a neuron model with a two-sided threshold (including negative threshold crossings), rely on the existence of analytic expressions for the post-synaptic potential kernels, provide specialized algorithms tailored to specific loss functions and consider minimalistic regression tasks.

The chronotron26 uses a gradient-based learning rule based on the Victor-Purpura metric which enables a single leaky integrate-and-fire neuron to learn a target spike train. Our work, as well as the works mentioned above which derive exact gradients, applies the implicit function theorem to differentiate spike times with respect to synaptic weights. A different approach is to consider ratios of the neuronal time constants for which analytic expressions for first spike times can be given and to derive the corresponding gradients, as done in refs. 27,28,29,30. The gradient computations used in those works are special cases of our method.

The seminal Tempotron model uses gradient descent to adjust the sub-threshold voltage maximum in a single neuron31 and has recently been generalized to the spike threshold surface formalism32 that uses the exact gradient of the critical thresholds \(\vartheta ^*_k\) at which a leaky integrate-and-fire neuron transitions from emitting k to \(k-1\) spikes; computing this gradient is not considered in this work. The adjoint method was recently used to optimize neural ordinary differential equations33 and neural jump stochastic differential equations34 as well as to derive the gradient for a smoothed spiking neuron model without reset35.

We first define the used spiking neuron model and then proceed to state our main results.

Leaky integrate-and-fire neural network model

We define a network of N leaky integrate-and-fire neurons with arbitrary (up to self-connections) recurrent connectivity (Table 1). We set the leak potential to zero and choose parameter-independent initial conditions. Note that the Spike-Response Model (SRM)36 with double-exponential or \(\alpha \)-shaped PSPs is generally an integral expression of the model given in Table 1 with corresponding time constants.

Table 1 The leaky integrate-and-fire spiking neural network model. In between spikes, the vectors of membrane potentials V and synaptic currents I evolve according to the free dynamics. When some neuron \(n\in [1..N]\) crosses the threshold \(\vartheta \), the transition condition is fulfilled, causing a spike. This leads to a reset of the membrane potential as well as post-synaptic current jumps. \(W\in \mathbb {R}^{N\times N}\) is the weight matrix with zero diagonal and \(e_n\in \mathbb {R}^N\) is the unit vector with a 1 at index n and 0 at all other indices. We use − and \(+\) to denote quantities before and after a given spike
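The following sketch illustrates the model for a single neuron and checks the partial derivative jump stated in Fig. 1B. Because the body of Table 1 is not reproduced here, the free dynamics are assumed to take the standard current-based form \(\tau _\text {mem}\dot{V}=-V+I\), \(\tau _\text {syn}\dot{I}=-I\); the parameter values (in milliseconds), the weight and all function names are illustrative.

```python
import math

TAU_MEM, TAU_SYN, THETA = 20.0, 5.0, 1.0  # illustrative values (ms, ms, threshold)

def kernel(s):
    """V(s) for V(0) = 0 and I(0) = 1 under the assumed free dynamics."""
    c = TAU_SYN / (TAU_SYN - TAU_MEM)
    return c * (math.exp(-s / TAU_SYN) - math.exp(-s / TAU_MEM))

def spike_time(w):
    """First threshold crossing, located by bisection on the rising phase of V."""
    t_peak = math.log(TAU_MEM / TAU_SYN) * TAU_MEM * TAU_SYN / (TAU_MEM - TAU_SYN)
    if w * kernel(t_peak) <= THETA:
        raise ValueError("weight too small to elicit a spike")
    lo, hi = 0.0, t_peak
    for _ in range(200):
        mid = 0.5 * (lo + hi)
        if w * kernel(mid) < THETA:
            lo = mid
        else:
            hi = mid
    return 0.5 * (lo + hi)

def voltage(t, w):
    """V(t) for I(0) = w, V(0) = 0, with the reset at the first spike (valid up to a second spike)."""
    t_s = spike_time(w)
    if t < t_s:
        return w * kernel(t)
    i_spike = w * math.exp(-t_s / TAU_SYN)  # the current is continuous across the spike
    return i_spike * kernel(t - t_s)        # V restarts from the reset value 0

def dv_dw(t, w, eps=1e-6):
    """Central-difference estimate of the partial derivative dV/dw at fixed t."""
    return (voltage(t, w + eps) - voltage(t, w - eps)) / (2.0 * eps)

w = 10.0
t_s = spike_time(w)
v_dot_minus = (w * math.exp(-t_s / TAU_SYN) - THETA) / TAU_MEM
predicted_ratio = THETA / (TAU_MEM * v_dot_minus) + 1.0  # factor from Fig. 1B
dt = 1e-3
measured_ratio = dv_dw(t_s + dt, w) / dv_dw(t_s - dt, w)
print(predicted_ratio, measured_ratio)
```

The two printed values should agree up to terms of order dt, which reflects that the jump factor depends only on the threshold, the membrane time constant and the slope of the membrane potential at the spike.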

Gradient via backpropagation

Consider smooth loss functions \(l_V(V, t)\), \(l_{\mathrm {p}}(t^{\text {post}}_{})\) that depend on the membrane potentials V, time t and the set of post-synaptic spike times \(t^{\text {post}}_{}\). The total loss is given by

$$\begin{aligned} \mathcal {L} = l_{\mathrm {p}}(t^{\text {post}}_{})+\int _0^T l_V(V(t), t){\mathrm {d}}t. \end{aligned}$$
(1)

Our main result is that the derivative of the total loss with respect to a specific weight \(w_{ji}=(W)_{ji}\) that connects pre-synaptic neuron i (the firing neuron) to post-synaptic neuron j (the receiving neuron) is given by a sum over the spikes caused by i,

$$\begin{aligned} \frac{{\mathrm {d}}\mathcal {L}}{{\mathrm {d}}w_{ji}} = - \tau _\text {syn}\sum _{\text {spikes from }i} (\lambda _I)_j, \end{aligned}$$
(2)

where \(\lambda _I\) is the adjoint variable (Lagrange multiplier) corresponding to the synaptic current I. Equation (2) therefore samples the post-synaptic neuron’s adjoint variable \((\lambda _I)_j\) at the spike times caused by neuron i.

After the neuron dynamics given by Table 1 have been computed from \(t=0\) to \(t=T\), the adjoint state variable \(\lambda _I\) is computed in reverse time (i.e., from \(t=T\) to \(t=0\)) as the solution of the system of adjoint equations defined in Table 2. The dynamical system defined by Table 2 is the adjoint spiking network to the leaky integrate-and-fire network (Table 1) which backpropagates error signals at the spike times \(t^{\text {post}}_{}\).

Table 2 The adjoint spiking network to Table 1 that computes the adjoint variable \(\lambda _I\) needed for the gradient [Eq. (2)]. The adjoint variables are computed in reverse time (i.e., from \(t=T\) to \(t=0\)) with \('=-\frac{{\mathrm {d}}}{{\mathrm {d}}t}\) denoting the reverse time derivative. \((\lambda _V^-)_{n(k)} \) experiences jumps at the spike times \(t^{\text {post}}_{k}\), where n(k) is the index of the neuron that caused the kth spike. Computing this system amounts to the backpropagation of errors in time. The initial conditions are \(\lambda _V(T)=\lambda _I(T)=0\) and we provide \(\lambda _V^-\) in terms of \(\lambda _V^+\) because the computation happens in reverse time

Equation (2) and Table 2 suggest a simple algorithm, EventProp, to compute the gradient (Algorithm 1). Notably, if the loss is voltage-independent (i.e., \(l_V=0\)), the backward pass of the algorithm requires only the spike times \(t^{\text {post}}_{}\) and the synaptic current of the firing neurons at their respective firing times to be retained from the forward pass. The membrane potential at spike times is fixed to the threshold \(\vartheta \) and is therefore implicitly retained; the synaptic current then determines the temporal derivative of the membrane potential at the spike time, \(\dot{V}^-\), and needs to be stored for the backward pass. The memory requirement of the algorithm scales as \(\mathcal {O}(S)\), where S is the number of post-synaptic spikes in the network. A feed-forward architecture corresponds to a strictly block-triangular weight matrix W in which each non-zero block connects two given layers. In that case, the forward and backward pass can be computed in a layer-wise fashion.

In case of a voltage-dependent loss \(l_V\ne 0\), the algorithm has to store the non-zero components of \(\frac{\partial l_V}{\partial V}\) along the forward trajectory. The loss \(l_V\) may depend on the voltage at a discrete time \(t_i\) using the Dirac delta, \(l_V(V(t), t) = V(t)\delta (t_i-t)\), causing a jump of \(\lambda _V\) of magnitude \(\tau _\text {mem}^{-1}\) at time \(t_i\). Note that in many practical scenarios as found in deep learning, the loss \(l_V\) depends only on the state of a constant number of neurons, irrespective of network size. If \(l_V\) depends on the voltage of non-firing readout neurons, we have \(l_V^+ = l_V^-\) and the corresponding term in the jump given in Table 2 vanishes.
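The magnitude of this jump can be made explicit with a short calculation. As a sketch, assume that the free adjoint dynamics of a single neuron read \(\tau _\text {mem}\lambda _V'=-\lambda _V-\frac{\partial l_V}{\partial V}\) with \('=-\frac{{\mathrm {d}}}{{\mathrm {d}}t}\) (the body of Table 2 is not reproduced here, so this form is an assumption). For \(l_V(V(t),t)=V(t)\delta (t_i-t)\) we have \(\frac{\partial l_V}{\partial V}=\delta (t_i-t)\), and integrating over a small interval around \(t_i\) gives

$$\begin{aligned} \tau _\text {mem}\left[ \lambda _V(t_i-\varepsilon )-\lambda _V(t_i+\varepsilon )\right]&= -\int _{t_i-\varepsilon }^{t_i+\varepsilon }\lambda _V(t)\,{\mathrm {d}}t-\int _{t_i-\varepsilon }^{t_i+\varepsilon }\delta (t_i-t)\,{\mathrm {d}}t \\&\rightarrow -1 \quad \text {as } \varepsilon \rightarrow 0, \end{aligned}$$

so that \(\lambda _V\) changes by \(\tau _\text {mem}^{-1}\) in magnitude when the backward pass crosses \(t_i\). Under the assumed sign convention, \(\lambda _V(t_i^-)=\lambda _V(t_i^+)-\tau _\text {mem}^{-1}\).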

If \(l_V\) is either zero or depends only on voltages at discrete points in time, EventProp can be computed in a purely event-based manner. Figure 2 illustrates how EventProp computes the gradient of a spike time based loss function for two leaky integrate-and-fire neurons where one neuron receives Poisson spike trains via 100 synapses and is connected to the other neuron via a single feed-forward weight w.
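The smallest instance of this procedure can be written out explicitly: a single neuron driven by one input spike at \(t=0\) across a weight w, with the loss given by the neuron's first spike time. The sketch below assumes the free dynamics \(\tau _\text {mem}\dot{V}=-V+I\), \(\tau _\text {syn}\dot{I}=-I\) and adjoint dynamics of the form \(\tau _\text {mem}\lambda _V'=-\lambda _V\), \(\tau _\text {syn}\lambda _I'=-\lambda _I+\lambda _V\) with the jump \(\lambda _V^-=\lambda _V^++\left( \vartheta \lambda _V^++\partial l_{\mathrm {p}}/\partial t^{\text {post}}\right) /(\tau _\text {mem}\dot{V}^-)\); since the bodies of Tables 1 and 2 are not reproduced here, these forms, the parameter values and all function names are assumptions of the sketch rather than the simulator used in this work.

```python
import math

TAU_MEM, TAU_SYN, THETA = 20.0, 5.0, 1.0  # illustrative values (ms, ms, threshold)

def kernel(s):
    """V(s) for V(0) = 0 and I(0) = 1 under the assumed free dynamics."""
    c = TAU_SYN / (TAU_SYN - TAU_MEM)
    return c * (math.exp(-s / TAU_SYN) - math.exp(-s / TAU_MEM))

def spike_time(w):
    """First threshold crossing, located by bisection on the rising phase of V."""
    t_peak = math.log(TAU_MEM / TAU_SYN) * TAU_MEM * TAU_SYN / (TAU_MEM - TAU_SYN)
    if w * kernel(t_peak) <= THETA:
        raise ValueError("weight too small to elicit a spike")
    lo, hi = 0.0, t_peak
    for _ in range(200):
        mid = 0.5 * (lo + hi)
        if w * kernel(mid) < THETA:
            lo = mid
        else:
            hi = mid
    return 0.5 * (lo + hi)

def forward(w):
    """Forward pass: keep only the spike time and the synaptic current at the spike."""
    t_s = spike_time(w)
    return t_s, w * math.exp(-t_s / TAU_SYN)

def eventprop_gradient(t_s, i_spike):
    """Backward pass for the loss L = t_s, using closed-form adjoint solutions."""
    v_dot_minus = (i_spike - THETA) / TAU_MEM
    # Jump of lambda_V at the spike; lambda_V^+ = 0 because of the zero initial condition at T.
    lambda_v = 1.0 / (TAU_MEM * v_dot_minus)  # uses dL/dt_s = 1
    # Solve the assumed adjoint dynamics in reverse time from t_s back to the input spike at t = 0.
    c = TAU_MEM / (TAU_MEM - TAU_SYN)
    lambda_i = lambda_v * c * (math.exp(-t_s / TAU_MEM) - math.exp(-t_s / TAU_SYN))
    # Eq. (2): a single spike crossed the weight, at t = 0.
    return -TAU_SYN * lambda_i

w = 10.0
grad = eventprop_gradient(*forward(w))
eps = 1e-6
grad_fd = (spike_time(w + eps) - spike_time(w - eps)) / (2.0 * eps)  # central difference
print(grad, grad_fd)
```

In this sketch, the backward pass only uses the two numbers returned by the forward pass, in line with the memory requirement discussed above. The gradient is negative because a larger weight leads to an earlier spike.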

Figure 2

Illustration of EventProp-based gradient calculation in two leaky integrate-and-fire neurons connected with weight w and a spike-time dependent loss \(\mathcal {L}\). The forward pass (B, C) computes the spike times for both neurons and the backward pass (D–G) backpropagates errors at spike times, yielding the gradient as given in Eq. (2). (A) The upper neuron receives 100 independent Poisson spike trains with frequency \(200\,\text {Hz}\) across randomly initialized weights and is connected to the lower neuron via a single weight w. The loss \(\mathcal {L}\) is a sum of the spike times of the lower neuron. (B, C) Membrane potential of upper and lower neuron. Spike times of the upper neuron are indicated using arrows. (D, E) Adjoint variable \(\lambda _I\) of upper and lower neuron. The lower neuron backpropagates its error signal \(\lambda _V-\lambda _I\) at the upper neuron’s spike times (indicated by arrows). (F, G) Accumulated gradient for one of the 100 input weights of the upper neuron and the weight w connecting the upper and lower neuron. EventProp computes the adjoint variables from \(t=T\) to \(t=0\) and accumulates the gradients by sampling \(-\tau _\text {syn}\lambda _I\) when spikes are transmitted across the respective weight. The gradients computed in this way match the gradients computed via central differences (dashed lines) up to a relative deviation of less than \(10^{-7}\)

Simulation results

We demonstrate learning with EventProp using a custom event-based simulator and the Yin-Yang37 and MNIST38 datasets. In both cases, we use a single hidden layer and spike latency encoding of the input data. The Yin-Yang dataset is classified using the time to first spike of a layer of readout neurons while the MNIST dataset is classified using the voltage maxima of a layer of non-firing readout neurons. The simulator computes gradients using EventProp as described in Algorithm 1; specifically, it uses an event queue and root-bracketing to compute post-synaptic spike times in the forward pass (using exact integration of the membrane potential39) and backpropagates errors by attaching error signals to spikes in the backward pass and using reverse traversal of the event queue. We optimized synaptic weights using the calculated gradients via the Adam optimizer40, without clipping gradients.

By initializing synaptic weights such that the network started in a non-quiescent state, we found that no explicit regularization of firing rates was needed to obtain the reported results in both cases. Hyperparameters were optimized using Gaussian process optimization41 and manual tuning using the validation set of the respective dataset. The resulting parameters (see Table 3) were then evaluated using the test set.

Table 3 Simulation parameters used for the results described in the main text

Yin-Yang dataset

Figure 3

We used EventProp and a time-to-first-spike loss function to train a two-layer leaky integrate-and-fire network on the Yin-Yang dataset. (A) Illustration of the two-dimensional training dataset. The three different classes are shown in red, green and blue. This dataset was encoded using spike time latencies (see D). (B, C) Training results in terms of test error and loss averaged over 10 different random seeds (individual traces shown as grey lines). (D) Data points (x, y) were transformed into \((x, 1-x, y, 1-y)\) and encoded using spike time latencies. We added a fixed spike at time \(t_\text {bias}\). (E) Spike time latencies \(\Delta t\) of the three output neurons (encoding the blue, red or green class) after training, for all samples in the test set and a specific random seed. Latencies are relative to the first spike among the three neurons and given in units of \(t_{\mathrm {max}}\). A latency of zero (bright yellow dots) implies that the corresponding neuron fired the first spike, determining the class assignment. Missing spikes are denoted using green crosses

The Yin-Yang dataset37 is a two-dimensional, non-linearly separable dataset on which a shallow classifier achieves around \(64\%\) accuracy; it therefore requires a hidden layer and backpropagation of errors for high classification accuracy. In contrast, the MNIST dataset can be classified using a linear classifier with at least \(88\%\) accuracy38.

Each two-dimensional data point of the dataset (x, y) was transformed into four dimensions as \((x, 1-x, y, 1-y)\) and encoded using spike latencies in the interval \([0, t_\text {max}]\) (see Fig. 3D). We added a fixed bias spike at time \(t_\text {bias}\) for a total of five input spikes per data point. The resulting spike patterns were used as input to a two-layer network composed of leaky integrate-and-fire neurons. The output layer consisted of three neurons that each encoded one of the three classes, with each data point being assigned the class of the neuron that fired the earliest spike.
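The encoding step can be sketched as follows. The text specifies the feature transformation and the latency interval but not the exact mapping from feature values to spike times, so the linear mapping, the default values of t_max and t_bias and the function name are assumptions of this sketch.

```python
def encode_point(x, y, t_max=1.0, t_bias=0.0):
    """Map a data point in [0, 1]^2 to five input spike times (hypothetical linear latency code)."""
    features = (x, 1.0 - x, y, 1.0 - y)
    # Assumption: the latency grows linearly with the feature value and lies in [0, t_max].
    spikes = [f * t_max for f in features]
    spikes.append(t_bias)  # fixed bias spike, identical for every data point
    return spikes

print(encode_point(0.25, 0.5, t_max=2.0))
```

Including both a coordinate and its complement ensures that every data point produces the same number of input spikes regardless of its position.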

In analogy to ref. 27, we used a cross-entropy loss defined using the first output spike times per neuron,

$$\begin{aligned} \mathcal {L} = \frac{1}{N_{{\mathrm {batch}}}}\sum _{i=1}^{N_{{\mathrm {batch}}}}\left[ -\log \left[ \frac{\exp \left( -t^{\text {post}}_{i,l(i)}/\tau _0\right) }{\sum _{k=1}^3 \exp \left( -t^{\text {post}}_{i,k}/\tau _0\right) }\right] + \alpha \left[ \exp \left( \frac{t^{\text {post}}_{i,l(i)}}{\tau _1}\right) - 1\right] \right] , \end{aligned}$$
(3)

where \(t^{\text {post}}_{i,k}\) is the first spike time of neuron k for the ith sample, l(i) is the index of the correct label for the ith sample, \(N_{{\mathrm {batch}}}\) is the number of samples in a given batch and \(\tau _0\), \(\tau _1\) and \(\alpha \) are hyperparameters of the loss function. The first term corresponds to a cross-entropy loss function over the softmax function applied to the negative spike times (we use negative spike times as the class assignment is determined by the smallest spike time) and encourages an increase of the spike time difference between the label neuron and all other neurons. As the first term depends only on the relative spike times, the second term is a regularization term that encourages early spiking of the label neuron.
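As an illustration, the loss can be evaluated as in the sketch below. The regularization term enters with a positive sign so that minimizing the loss favours early spikes of the label neuron, as described above; the default hyperparameter values and the function name are hypothetical, and missing output spikes are not handled.

```python
import math

def ttfs_loss(spike_times, labels, tau_0=1.0, tau_1=1.0, alpha=0.01):
    """Mean time-to-first-spike loss over a batch.

    spike_times[i][k] is the first spike time of output neuron k for sample i,
    labels[i] is the index of the correct output neuron for sample i.
    """
    total = 0.0
    for times, label in zip(spike_times, labels):
        logits = [-t / tau_0 for t in times]  # earlier spikes give larger logits
        m = max(logits)
        log_norm = m + math.log(sum(math.exp(z - m) for z in logits))  # stable log-sum-exp
        cross_entropy = log_norm - logits[label]
        regularizer = alpha * (math.exp(times[label] / tau_1) - 1.0)  # penalizes late label spikes
        total += cross_entropy + regularizer
    return total / len(labels)

times = [[0.1, 0.5, 0.9]]
print(ttfs_loss(times, [0]), ttfs_loss(times, [2]))
```

The loss is smaller when the label neuron is the first to spike (first call) than when it is the last (second call).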

Training results are shown in Fig. 3. After training, the test accuracy was 98.1(2)% (mean and standard deviation over 10 different random seeds). This is comparable to the results shown in ref. 27, which reports 95.9(7)% accuracy with a smaller hidden layer (120 neurons, compared with 200 in this work).

MNIST dataset

Figure 4

We used EventProp and a two-layer network composed of a hidden layer of leaky integrate-and-fire neurons and a readout layer of non-firing neurons to classify the MNIST dataset, with the readout neuron with the largest voltage deflection determining the class assignment. (A, B) Training results in terms of test error and loss averaged over 10 different random seeds (individual traces shown as grey lines). (C) Confusion matrix after training for a specific random seed and using the test set. (D) Voltage traces of all readout layer neurons for three different samples from the test set, where voltage traces of neurons corresponding to wrong labels are plotted using dashed lines

We encoded each digit of the MNIST dataset38 by transforming each of the \(28\cdot 28=784\) pixels into spike latencies in the interval \([0, t_{\mathrm {max}}]\) (pixels corresponding to a value of 0 or 1 out of 255 were not converted to spikes). The resulting spike patterns were used as input to a two-layer network composed of a hidden layer of leaky integrate-and-fire neurons and a readout layer of non-firing leaky integrator neurons. We used a cross-entropy loss function over the softmax function applied to the voltage maxima of the readout neurons (max-over-time),

$$\begin{aligned} \mathcal {L} = -\frac{1}{N_{\mathrm {batch}}}\sum _{i=1}^{N_{\mathrm {batch}}}\log \left[ \frac{\exp \left( \max _t V_{l(i)}(t)\right) }{\sum _{k=1}^{10} \exp \left( \max _t V_k(t)\right) }\right] , \end{aligned}$$
(4)

where \(V_k(t)\) is the voltage trace of the kth readout neuron, l(i) is the index of the correct label for the ith sample and \(N_{\mathrm {batch}}\) is the number of samples in a given batch. Note that we can write the maximum voltage as \(\max _t V_k(t)=\int V_k(t)\delta (t-t_{\mathrm {max}}){\mathrm {d}}t\) with the time of the maximum \(t_{\mathrm {max}}\) and the Dirac delta \(\delta \), allowing us to apply the chain rule to find the jump of \(\lambda _{V_k}\) (cf. Table 2) at time \(t_{\mathrm {max}}\) (terms containing the distributional derivative of \(\delta \) are always zero).
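For a single sample, the max-over-time loss can be sketched as follows. The sketch operates on voltage traces sampled on a time grid, which is an assumption made for illustration (the simulator used in this work is event-based), and the function name is hypothetical.

```python
import math

def max_over_time_loss(voltage_traces, label):
    """Cross-entropy over the softmax of the per-neuron voltage maxima for one sample.

    voltage_traces[k] is a list of sampled voltages of readout neuron k.
    Returns the loss and, per neuron, the sample index of the maximum, i.e. the
    point in time at which the jump of the adjoint variable would be applied.
    """
    argmax = [max(range(len(trace)), key=trace.__getitem__) for trace in voltage_traces]
    peaks = [trace[i] for trace, i in zip(voltage_traces, argmax)]
    m = max(peaks)
    log_norm = m + math.log(sum(math.exp(p - m) for p in peaks))  # stable log-sum-exp
    return log_norm - peaks[label], argmax

traces = [[0.0, 2.0, 1.0], [0.0, 0.5, 0.2]]
loss, peak_indices = max_over_time_loss(traces, label=0)
print(loss, peak_indices)
```

Only the voltages at the times of the maxima contribute to the loss, which is why the backward pass requires a single jump per readout neuron.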

During training, input spikes were dropped with probability \(p_{\mathrm {drop}}\) in order to avoid overfitting. To obtain a validation set, we extracted and removed 5000 samples from the training set.
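The input encoding and the spike dropout can be sketched together. The text does not specify the mapping from pixel values to latencies, so the linear mapping in which brighter pixels spike earlier is an assumption of this sketch, as are the function name and the default values.

```python
import random

def encode_image(pixels, t_max=1.0, p_drop=0.0, rng=None):
    """Return (input index, spike time) pairs for a flattened image with values in 0..255.

    Assumption: brighter pixels spike earlier, with a linear latency code in [0, t_max].
    Pixels with a value of 0 or 1 emit no spike, as stated in the text.
    """
    rng = rng or random.Random(0)
    spikes = []
    for index, value in enumerate(pixels):
        if value <= 1:
            continue
        if rng.random() < p_drop:  # training-time dropout of input spikes
            continue
        spikes.append((index, t_max * (1.0 - value / 255.0)))
    return spikes

print(encode_image([0, 1, 255, 128]))
```

With p_drop set to zero, as would be used for validation and testing, the encoding is deterministic.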

Training results are shown in Fig. 4. After training, the test accuracy was 97.6(1)% (mean and standard deviation over 10 different random seeds). This represents competitive classification performance when compared with previously published results using spiking networks with a single, fully connected hidden layer (Table 4).

Table 4 Comparison of previously published classification results on the MNIST dataset for spiking neural networks that are trained using supervised learning with a single, fully connected (non-convolutional) hidden layer and temporal encoding of input data. The second column provides the number of hidden neurons

Discussion

We have derived and provided an algorithm (EventProp) to compute the gradient of a general loss function for a spiking neural network composed of leaky integrate-and-fire neurons. The parameter-dependent spike discontinuities were treated in a well-defined manner using the adjoint method in combination with partial derivative jumps, without approximations or smoothing operations. EventProp uses the resulting adjoint spiking network to backpropagate errors in order to compute the exact gradient. Its forward pass requires computing the spike times of pre-synaptic neurons that transmit spikes to post-synaptic neurons, while the backward pass backpropagates errors at these spike times using the reverse path (i.e., from post-synaptic to pre-synaptic neurons). The rigorous treatment of spike discontinuities in combination with an event-based computation of the exact gradient represents a significant conceptual advance in the study of gradient-based learning methods for spiking neural networks.

An apparent issue with gradient-descent-based learning in the context of spiking networks is that the magnitude of the gradient diverges at the critical points in parameter space (note the \(\dot{V}^{-1}\) term in the jump given in Table 2; this term diverges as the membrane potential becomes tangent to the threshold, where \(\dot{V}\rightarrow 0\)). Indeed, this is a known issue in the broader context of optimal control of dynamical systems with parameter-dependent state transitions4,8. While this divergence can be mitigated using gradient clipping in practice, exact gradients of commonly considered loss functions lead to learning dynamics that are ignorant with respect to these critical points and are therefore unable to selectively recruit additional spikes or dismiss existing spikes. In contrast, surrogate gradient methods continuously transmit errors across neurons and combine these with a non-linear function of the distance of the membrane potential to the threshold. It is therefore plausible that surrogate gradients represent a form of implicit regularization. Neftci et al.3 report that the surrogate gradient approximates the true gradient in a minimalistic binary classification task while at the same time remaining finite and continuous along an interpolation path in weight space. Hybrid algorithms that combine the exact gradient with explicit regularization techniques could be a direction for future research and provide more principled learning algorithms as compared to ad-hoc replacements of threshold functions.

This work is based on the widely used leaky integrate-and-fire neuron model. Extensions to this model, such as fixed refractory periods, adaptive thresholds or multiple compartments can be treated in an analogous way46. While the absence of explicit solutions to the resulting differential equations can require the use of sophisticated numerical techniques for event-based simulations, such extensions can significantly enhance the computational capabilities of spiking networks. For example, ref. 17 uses adaptive thresholds to implement LSTM-like memory cells in a recurrent spiking neural network.

Neuromorphic hardware is an increasingly active research subject47,48,49,50,51,52,53,54,55,56,57 and implementing EventProp on such hardware is a natural consideration. The adjoint dynamics as given in Table 2 represent a type of spiking neural network which, instead of spiking dynamically, transmits errors at fixed times \(t^{\text {post}}_{}\) that are scaled with factors \(\dot{V}^{-1}\) retained from the forward pass. Therefore, a neuromorphic implementation could store spike times and scaling factors locally at each neuron, where they could be combined with the dynamic error signal (\(\lambda _V -\lambda _I\) in Table 2) in the backward pass. This requires the ability to read out neuronal state variables (membrane potential and synaptic current) in both the forward and backward pass. The resulting error signals could be distributed across the network using event-based communication schemes similar to, for example, the address-event representation protocol58. As mentioned above, EventProp can be extended to multi-compartment neuron models as used in a recent neuromorphic architecture59.

We used a two-layer feed-forward architecture to demonstrate learning using EventProp. The algorithm can, however, compute the gradient for arbitrary recurrent or convolutional architectures. Its computational and spatial complexity scales linearly with network size (assuming constant average firing rates per neuron), analogous to backpropagation in non-spiking artificial neural networks. The performance in more complex tasks therefore hinges on the general efficacy of gradient-based optimization in spiking networks. As mentioned above, gradients of loss functions defined in terms of spike times or membrane potentials ignore the presence of critical parameters where spikes appear or disappear. We suggest that studying regularization techniques which deal with this fundamental issue in a targeted manner could enable powerful learning algorithms for spiking networks. By providing a theoretical foundation for backpropagation in spiking networks, we support future research that combines such regularization techniques with the computation of exact gradients.

Methods

Partial derivatives in a hybrid system

In the following, we use the example of a bouncing ball (Fig. 1A) to illustrate the calculation of partial derivatives in a dynamical system with state discontinuities. A general treatment of the topic is given in other literature8,60. The discontinuities occurring in the leaky integrate-and-fire neuron are treated analogously in our derivation of the gradient (see corresponding methods subsection).

The differential equation describing the bouncing ball with height y is

$$\begin{aligned} \ddot{y} = -g \end{aligned}$$
(5)

with gravitational acceleration g. Defining the ball’s velocity as \(v\equiv \dot{y}\), this is equivalent to a two-dimensional system

$$\begin{aligned} \dot{v}&= - g, \end{aligned}$$
(6a)
$$\begin{aligned} \dot{y}&= v. \end{aligned}$$
(6b)

The initial conditions are

$$\begin{aligned} v(0)&= 0, \end{aligned}$$
(7a)
$$\begin{aligned} y(0)&= y_0 \end{aligned}$$
(7b)

where \(y_0>0\) is the parameter of interest defining the ball’s initial height. The given equations determine the state trajectory y(t) up to the moment of impact with the ground at \(y=0\). Likewise, the trajectories of the partial derivatives with respect to \(y_0\) are given by differentiation of Eqs. (6) and (7)61,

$$\begin{aligned} \frac{{\mathrm {d}}}{{\mathrm {d}}t}\frac{\partial v}{\partial y_0}&= 0, \end{aligned}$$
(8a)
$$\begin{aligned} \frac{{\mathrm {d}}}{{\mathrm {d}}t}\frac{\partial y}{\partial y_0}&= \frac{\partial v}{\partial y_0}, \end{aligned}$$
(8b)

with initial conditions

$$\begin{aligned} \frac{\partial v}{\partial y_0}(0)&= 0, \end{aligned}$$
(9a)
$$\begin{aligned} \frac{\partial y}{\partial y_0}(0)&= 1. \end{aligned}$$
(9b)

The state discontinuity occurs when the ball hits the ground and we have

$$\begin{aligned} y^- = 0 \end{aligned}$$
(10)

at the time of impact \(t_{{\mathrm {r}}}\). The ball is inelastically reflected, losing a fraction of its energy. Specifically, the system is re-initialized as

$$\begin{aligned} v^+&= -0.8v^-, \end{aligned}$$
(11a)
$$\begin{aligned} y^+&= y^-, \end{aligned}$$
(11b)

where − and \(+\) denote the state before and after the transition (\(v^\pm \), \(y^\pm \) are functions of \(t_r\) and \(y_0\)). Equations (10) and (11) together uniquely determine the partial derivatives after the reflection. The implicit function theorem62 applied to Eq. (10) guarantees (because \(v\ne 0\)) the existence of a function \(t_{{\mathrm {r}}}(y_0)\) that locally describes how the impact time changes with \(y_0\), with its derivative given by

$$\begin{aligned} \frac{{\mathrm {d}}t_{\mathrm {r}}}{{\mathrm {d}}y_0} =-\frac{1}{\dot{y}^-}\frac{\partial y^-}{\partial y_0} = -\frac{1}{v^-}\frac{\partial y^-}{\partial y_0}. \end{aligned}$$
(12)

Likewise, the implicit function theorem applies to Eq. (11) (because \(v\ne 0\), \(\dot{v} \ne 0\)), yielding after differentiation

$$\begin{aligned} \frac{\partial v^+}{\partial y_0} + \dot{v}^+\frac{{\mathrm {d}}t_r}{{\mathrm {d}}y_0}&= -0.8\left[ \frac{\partial v^-}{\partial y_0} + \dot{v}^-\frac{{\mathrm {d}}t_r}{{\mathrm {d}}y_0}\right] , \end{aligned}$$
(13a)
$$\begin{aligned} \frac{\partial y^+}{\partial y_0} + \dot{y}^+\frac{{\mathrm {d}}t_r}{{\mathrm {d}}y_0}&= \frac{\partial y^-}{\partial y_0} + \dot{y}^-\frac{{\mathrm {d}}t_r}{{\mathrm {d}}y_0}. \end{aligned}$$
(13b)

The partial derivatives after the transition can now be found by solving the system of equations given by Eqs. (11), (12) and (13),

$$\begin{aligned} \frac{\partial v^+}{\partial y_0}&= -0.8\frac{\partial v^-}{\partial y_0}-1.8g\frac{1}{v^-}\frac{\partial y^-}{\partial y_0}, \end{aligned}$$
(14a)
$$\begin{aligned} \frac{\partial y^+}{\partial y_0}&= -0.8\frac{\partial y^-}{\partial y_0}, \end{aligned}$$
(14b)

where we have used \(\ddot{y} = -g\). Equation (14) provides the initial conditions for the integration of the partial derivatives after the transition; subsequent ground impacts can be treated equivalently. Figure 1A illustrates the behaviour of y(t) and \(\frac{\partial y}{\partial y_0}(t)\) using trajectories calculated numerically using the equations given here.
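The procedure above can be checked numerically. The following minimal Python sketch is not part of the original work; it integrates the free flight in closed form, applies the jumps of Eqs. (11) and (14) at each impact, and compares the resulting \(\frac{\partial y}{\partial y_0}\) with a central finite difference. The value of g and the names `simulate`, `RESTITUTION` and `max_impacts` are assumptions made for illustration.

```python
import math

G = 9.81           # gravitational acceleration g (assumed value, m/s^2)
RESTITUTION = 0.8  # velocity factor from Eq. (11a)


def simulate(y0, t_end, max_impacts=100):
    """Return (y, dy/dy0) at time t_end for a ball dropped from height y0."""
    t, y, v = 0.0, y0, 0.0
    sy, sv = 1.0, 0.0  # partial derivatives dy/dy0 and dv/dy0, Eq. (9)
    for _ in range(max_impacts):
        # Flight time until the next ground contact (positive root of y = 0).
        dt = (v + math.sqrt(v * v + 2.0 * G * y)) / G
        if t + dt > t_end:
            dt = t_end - t
            # Free flight, Eqs. (6) and (8), integrated in closed form.
            return y + v * dt - 0.5 * G * dt * dt, sy + sv * dt
        # State and partial derivatives just before the impact.
        v_minus = v - G * dt
        sy_minus, sv_minus = sy + sv * dt, sv
        # Jumps at the impact, Eqs. (11) and (14).
        v, y = -RESTITUTION * v_minus, 0.0
        sv = -RESTITUTION * sv_minus - (1.0 + RESTITUTION) * G * sy_minus / v_minus
        sy = -RESTITUTION * sy_minus
        t += dt
    raise RuntimeError("too many impacts before t_end")


if __name__ == "__main__":
    y0, t_end, h = 1.0, 1.0, 1e-6
    _, sensitivity = simulate(y0, t_end)
    finite_diff = (simulate(y0 + h, t_end)[0] - simulate(y0 - h, t_end)[0]) / (2.0 * h)
    print(sensitivity, finite_diff)
```

The finite difference is only meaningful when `t_end` does not coincide with an impact time, since the partial derivative jumps there.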

Adjoint method

We apply the adjoint method to a continuous, first order system of ordinary differential equations and refer the reader to63,64 for a more general setting. Consider an N-dimensional dynamical system \(x: t\mapsto x(t)\in \mathbb {R}^N\) with parameters \(p\in \mathbb {R}^P\) defined by the system of implicit first order ordinary differential equations

$$\begin{aligned} \dot{x} - F(x, p) = 0 \end{aligned}$$
(15)

and constant initial conditions \(G(x(0))=0\) where F, G are smooth vector-valued functions.

We are interested in computing the gradient of a loss that is the integral of a smooth function l over the trajectory of x,

$$\begin{aligned} \mathcal L = \int _0^T l(x, t){\mathrm {d}}t. \end{aligned}$$
(16)

We have

$$\begin{aligned} \frac{{\mathrm {d}}\mathcal {L}}{{\mathrm {d}}p_i} = \int _0^T \frac{\partial l}{\partial x}\cdot \frac{\partial x}{\partial p_i} {\mathrm {d}}t, \end{aligned}$$
(17)

where \(\cdot \) is the dot product and the dynamics of the partial derivatives \(\frac{\partial x}{\partial p_i}\) are given by applying Gronwall’s theorem61,

$$\begin{aligned} \frac{{\mathrm {d}}}{{\mathrm {d}}t}\frac{\partial x}{\partial p_i} = \frac{\partial F}{\partial x}\frac{\partial x}{\partial p_i} + \frac{\partial F}{\partial p_i}. \end{aligned}$$
(18)

Computing x(t) along with \(\frac{\partial x}{\partial p_i}(t)\) using Eqs. (15) and (18) allows us to calculate the gradient in Eq. (17) in a single forward pass. However, this procedure can incur prohibitive computational cost. When considering a recurrent neural network with N neurons and \(P=N^2\) synaptic weights, computing \(\frac{\partial x}{\partial p_i}(t)\) for all parameters requires storing and integrating \(PN=N^3\) partial derivatives.

The adjoint method allows us to avoid computing PN partial derivatives in the forward pass by instead computing N adjoint variables \(\lambda (t)\) in an additional backward pass. We add a Lagrange multiplier \(\lambda : t\mapsto \lambda (t)\in \mathbb {R}^N\) that constrains the system dynamics as given in Eq. (15),

$$\begin{aligned} \mathcal L = \int _0^T\left[ l(x, t)+\lambda \cdot \left( \dot{x} - F(x,p)\right) \right] {\mathrm {d}}t. \end{aligned}$$
(19)

Along trajectories where Eq. (15) holds, \(\lambda \) can be chosen arbitrarily without changing \(\mathcal {L}\) or its derivative. We get

$$\begin{aligned} \frac{{\mathrm {d}}\mathcal {L}}{{\mathrm {d}}p_i} = \int _0^T \left[ \frac{\partial l}{\partial x} \cdot \frac{\partial x}{\partial p_i} +\lambda \cdot \left( \frac{{\mathrm {d}}}{{\mathrm {d}}t}\frac{\partial x}{\partial p_i} - \frac{\partial F}{\partial x}\frac{\partial x}{\partial p_i}- \frac{\partial F}{\partial p_i}\right) \right] {\mathrm {d}}t. \end{aligned}$$
(20)

Using partial integration, we have

$$\begin{aligned} \int _0^T\lambda \cdot \frac{{\mathrm {d}}}{{\mathrm {d}}t}\frac{\partial x}{\partial p_i}{\mathrm {d}}t= -\int _0^T\dot{\lambda }\cdot \frac{\partial x}{\partial p_i}{\mathrm {d}}t+ \left[ \lambda \cdot \frac{\partial x}{\partial p_i} \right] _0^T. \end{aligned}$$
(21)

The boundary term vanishes: at \(t=T\) because we set \(\lambda (T)=0\), and at \(t=0\) because we chose parameter independent initial conditions (\(\frac{\partial x}{\partial p_i}(0)=0\)). The gradient becomes

$$\begin{aligned} \frac{{\mathrm {d}}\mathcal {L}}{{\mathrm {d}}p_i} = \int _0^T \left[ \left( \frac{\partial l}{\partial x} - \dot{\lambda }- \frac{\partial F}{\partial x}\lambda \right) \cdot \frac{\partial x}{\partial p_i} - \lambda \cdot \frac{\partial F}{\partial p_i}\right] {\mathrm {d}}t. \end{aligned}$$
(22)

By choosing \(\lambda \) to fulfill the adjoint differential equation

$$\begin{aligned} \dot{\lambda }= \frac{\partial l}{\partial x} - \frac{\partial F}{\partial x}\lambda \end{aligned}$$
(23)

we are left with

$$\begin{aligned} \frac{{\mathrm {d}}\mathcal {L}}{{\mathrm {d}}p_i} = - \int _0^T\lambda \cdot \frac{\partial F}{\partial p_i}{\mathrm {d}}t. \end{aligned}$$
(24)

The gradient can therefore be computed using Eq. (24), where the adjoint state variable \(\lambda \) is computed from \(t=T\) to \(t=0\) as the solution of the adjoint differential equation Eq. (23) with initial condition \(\lambda (T)=0\). This corresponds to backpropagation through time (BPTT) in discrete time artificial neural networks.
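As an illustration of Eqs. (23) and (24), the following Python sketch applies the adjoint method to a scalar toy system chosen here for convenience (it does not appear in the original work): \(F(x,p)=-px\) with \(l(x,t)=x\), for which the gradient is known in closed form. The forward and backward passes use explicit Euler steps; the function names, step count and parameter values are illustrative assumptions.

```python
import math


def adjoint_gradient(p, x0=1.0, T=2.0, steps=20000):
    """dL/dp for dx/dt = -p*x, L = integral of x over [0, T], via the adjoint method."""
    dt = T / steps
    # Forward pass: integrate Eq. (15) and store the trajectory.
    xs = [x0]
    for _ in range(steps):
        xs.append(xs[-1] + dt * (-p * xs[-1]))
    # Backward pass: integrate Eq. (23) from t = T to t = 0 with lambda(T) = 0.
    lam, grad = 0.0, 0.0
    for k in range(steps, 0, -1):
        # Eq. (24): the integrand is -lambda * dF/dp = lambda * x, since dF/dp = -x.
        grad += dt * lam * xs[k]
        # Eq. (23): d(lambda)/dt = dl/dx - (dF/dx) * lambda = 1 + p * lambda,
        # stepped backwards in time.
        lam -= dt * (1.0 + p * lam)
    return grad


def analytic_gradient(p, x0=1.0, T=2.0):
    """Closed-form derivative of L = x0 * (1 - exp(-p*T)) / p with respect to p."""
    e = math.exp(-p * T)
    return x0 * (p * T * e - (1.0 - e)) / p ** 2


if __name__ == "__main__":
    print(adjoint_gradient(0.5), analytic_gradient(0.5))
```

Only one adjoint variable is integrated in the backward pass, regardless of how many parameters enter F; with several parameters, the same \(\lambda \) would be reused in Eq. (24) for each \(p_i\).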

Derivation of gradient

We apply the adjoint method (see previous methods subsection) to the case of a spiking neural network (i.e., a hybrid, discontinuous system with parameter dependent state transitions). The following derivation is specific to the model given in Table 1. A fully general treatment of (adjoint) sensitivity analysis in hybrid systems can be found in8 or10.

The differential equations defining the free dynamics in implicit form are

$$\begin{aligned} f_V&\equiv \tau _\text {mem}\dot{V} +V -I = 0, \end{aligned}$$
(25a)
$$\begin{aligned} f_I&\equiv \tau _\text {syn}\dot{I} +I = 0, \end{aligned}$$
(25b)

where \(f_V\), \(f_I\) are again vectors of size N. We now split up the loss integral in Eq. (1) at the spike times \(t^{\text {post}}_{}\) and use vectors of Lagrange multipliers \(\lambda _V\), \(\lambda _I\) that fix the system dynamics \(f_V\), \(f_I\) between transitions.

$$\begin{aligned} \frac{{\mathrm {d}}\mathcal {L}}{{\mathrm {d}}w_{ji}} = \frac{{\mathrm {d}}}{{\mathrm {d}}w_{ji}}\left[ l_{\mathrm {p}}(t^{\text {post}}_{})+\sum _{k=0}^{N_{\text {post}}} \int _{t^{\text {post}}_{k}}^{t^{\text {post}}_{k+1}} \left[ l_V(V,t) + \lambda _V\cdot f_V + \lambda _I\cdot f_I\right] {\mathrm {d}}t\right] , \end{aligned}$$
(26)

where we set \(t^{\text {post}}_{0}=0\) and \(t^{\text {post}}_{N_{\text {post}}+1}=T\) and \(x\cdot y\) is the dot product of two vectors x, y. Note that because \(f_V\), \(f_I\) vanish along all considered trajectories, \(\lambda _V\) and \(\lambda _I\) can be chosen arbitrarily without changing \(\mathcal {L}\) or its derivative. Using Eq. (25) we have, as per Gronwall’s theorem61,

$$\begin{aligned} \frac{\partial f_V}{\partial w_{ji}}&= \tau _\text {mem}\frac{{\mathrm {d}}}{{\mathrm {d}}t}{\frac{\partial V}{\partial w_{ji}}} + \frac{\partial V}{\partial w_{ji}} - \frac{\partial I}{\partial w_{ji}}, \end{aligned}$$
(27a)
$$\begin{aligned} \frac{\partial f_I}{\partial w_{ji}}&= \tau _\text {syn}\frac{{\mathrm {d}}}{{\mathrm {d}}t}{\frac{\partial I}{\partial w_{ji}}} + \frac{\partial I}{\partial w_{ji}}, \end{aligned}$$
(27b)

where we have used the fact that the derivatives commute, \(\frac{\partial }{\partial w_{ji}} \frac{{\mathrm {d}}}{{\mathrm {d}}t} = \frac{{\mathrm {d}}}{{\mathrm {d}}t} \frac{\partial }{\partial w_{ji}}\) (the weights are fixed and have no time dependence). The gradient then becomes, by application of the Leibniz integral rule,

$$\begin{aligned} \frac{{\mathrm {d}}\mathcal {L}}{{\mathrm {d}}w_{ji}}&= \sum _{k=0}^{N_{\text {post}}} \bigg [\int _{t^{\text {post}}_{k}}^{t^{\text {post}}_{k+1}}\left[ \frac{\partial l_V}{\partial V} \cdot \frac{\partial V}{\partial w_{ji}} + \lambda _V \cdot \left( \tau _\text {mem}\frac{{\mathrm {d}}}{{\mathrm {d}}t}{\frac{\partial V}{\partial w_{ji}}} + \frac{\partial V}{\partial w_{ji}} - \frac{\partial I}{\partial w_{ji}}\right) + \lambda _I \cdot \left( \tau _\text {syn}\frac{{\mathrm {d}}}{{\mathrm {d}}t}{\frac{\partial I}{\partial w_{ji}}} + \frac{\partial I}{\partial w_{ji}}\right) \right] {\mathrm {d}}t\nonumber \\&\quad +\frac{\partial l_{\mathrm {p}}}{\partial t^{\text {post}}_{k}}\frac{{\mathrm {d}}t^{\text {post}}_{k}}{{\mathrm {d}}w_{ji}}+l^-_{V,k+1}\frac{{\mathrm {d}}t^{\text {post}}_{k+1}}{{\mathrm {d}}w_{ji}} - l^+_{V,k}\frac{{\mathrm {d}}t^{\text {post}}_{k}}{{\mathrm {d}}w_{ji}}\bigg ], \end{aligned}$$
(28)

where \(l_{V,k}^\pm \) is the voltage-dependent loss evaluated before (−) or after (\(+\)) the transition and we have used that \(f_V=f_I=0\) along all considered trajectories. Using partial integration, we have

$$\begin{aligned} \int _{t^{\text {post}}_{k}}^{t^{\text {post}}_{k+1}} \lambda _V \cdot \frac{{\mathrm {d}}}{{\mathrm {d}}t}\frac{\partial V}{\partial w_{ji}}{\mathrm {d}}t&= - \int _{t^{\text {post}}_{k}}^{t^{\text {post}}_{k+1}} \dot{\lambda }_V\cdot \frac{\partial V}{\partial w_{ji}}{\mathrm {d}}t+ \bigg [ \lambda _V\cdot \frac{\partial V}{\partial w_{ji}} \bigg ]_{t^{\text {post}}_{k}}^{t^{\text {post}}_{k+1}}, \end{aligned}$$
(29)
$$\begin{aligned} \int _{t^{\text {post}}_{k}}^{t^{\text {post}}_{k+1}} \lambda _I \cdot \frac{{\mathrm {d}}}{{\mathrm {d}}t}\frac{\partial I}{\partial w_{ji}}{\mathrm {d}}t&= - \int _{t^{\text {post}}_{k}}^{t^{\text {post}}_{k+1}} \dot{\lambda }_I\cdot \frac{\partial I}{\partial w_{ji}}{\mathrm {d}}t+ \bigg [ \lambda _I \cdot \frac{\partial I}{\partial w_{ji}} \bigg ]_{t^{\text {post}}_{k}}^{t^{\text {post}}_{k+1}}. \end{aligned}$$
(30)

Collecting terms in \(\frac{\partial V}{\partial w_{ji}}\), \(\frac{\partial I}{\partial w_{ji}}\), we have

$$\begin{aligned} \frac{{\mathrm {d}}\mathcal {L}}{{\mathrm {d}}w_{ji}}&=\sum _{k=0}^{N_{\text {post}}} \bigg [\int _{t^{\text {post}}_{k}}^{t^{\text {post}}_{k+1}}\bigg [\bigg ( \frac{\partial l_V}{\partial V} - \tau _\text {mem}\dot{\lambda }_V + \lambda _V\bigg ) \cdot \frac{\partial V}{\partial w_{ji}} + \left( -\tau _\text {syn}\dot{\lambda }_I + \lambda _I - \lambda _V\right) \cdot \frac{\partial I}{\partial w_{ji}} \bigg ]{\mathrm {d}}t\nonumber \\&\quad +\frac{\partial l_{\mathrm {p}}}{\partial t^{\text {post}}_{k}}\frac{{\mathrm {d}}t^{\text {post}}_{k}}{{\mathrm {d}}w_{ji}} +\tau _\text {mem}\big [ \lambda _V \cdot \frac{\partial V}{\partial w_{ji}} \big ]_{t^{\text {post}}_{k}}^{t^{\text {post}}_{k+1}}+\tau _\text {syn}\big [ \lambda _I \cdot \frac{\partial I}{\partial w_{ji}} \big ]_{t^{\text {post}}_{k}}^{t^{\text {post}}_{k+1}} +l^-_{V,k+1}\frac{{\mathrm {d}}t^{\text {post}}_{k+1}}{{\mathrm {d}}w_{ji}} -l^+_{V,k}\frac{{\mathrm {d}}t^{\text {post}}_{k}}{{\mathrm {d}}w_{ji}}\bigg ]. \end{aligned}$$
(31)

Since the Lagrange multipliers \(\lambda _V(t)\), \(\lambda _I(t)\) can be chosen arbitrarily, this form allows us to set the dynamics of the adjoint variables between transitions. Since the integration of the adjoint variables is done from \(t=T\) to \(t=0\) in practice (i.e., reverse in time), it is practical to transform the time derivative as \(\frac{{\mathrm {d}}}{{\mathrm {d}}t}\rightarrow -\frac{{\mathrm {d}}}{{\mathrm {d}}t}\). Denoting the new time derivative by \('\), we have

$$\begin{aligned} \tau _\text {mem}\lambda _V '&= -\lambda _V - \frac{\partial l_V}{\partial V}, \end{aligned}$$
(32a)
$$\begin{aligned} \tau _\text {syn}\lambda _I '&= -\lambda _I + \lambda _V. \end{aligned}$$
(32b)

The integrand in Eq. (31) therefore vanishes along the trajectory and we are left with a sum over the transitions. Since the initial conditions of V and I are assumed to be parameter independent, we have \(\frac{\partial V}{\partial w_{ji}}=\frac{\partial I}{\partial w_{ji}}=0\) at \(t=0\). We set the initial condition for the adjoint variables to be \(\lambda _V(T)=\lambda _I(T)=0\) to eliminate the boundary term for \(t=T\). We are therefore left with a sum over transitions \(\xi _k\) evaluated at the transition times \(t^{\text {post}}_{k}\),

$$\begin{aligned} \frac{{\mathrm {d}}\mathcal {L}}{{\mathrm {d}}w_{ji}}&=\sum _{k=1}^{N_{\text {post}}} \xi _{k} \end{aligned}$$
(33)

with the definition

$$\begin{aligned} \xi _k&\equiv \frac{\partial l_{\mathrm {p}}}{\partial t^{\text {post}}_{k}}\frac{{\mathrm {d}}t^{\text {post}}_{k}}{{\mathrm {d}}w_{ji}}+l_{V,k}^-\frac{{\mathrm {d}}t^{\text {post}}_{k}}{{\mathrm {d}}w_{ji}} - l_{V,k}^+\frac{{\mathrm {d}}t^{\text {post}}_{k}}{{\mathrm {d}}w_{ji}}\nonumber \\&\quad +\left[ \tau _\text {mem}\left( \lambda _{V}^-\cdot \frac{\partial V^-}{\partial w_{ji}} - \lambda _{V}^+\cdot \frac{\partial V^+}{\partial w_{ji}}\right) + \tau _\text {syn}\left( \lambda _{I}^-\cdot \frac{\partial I^-}{\partial w_{ji}} - \lambda _{I}^+\cdot \frac{\partial I^+}{\partial w_{ji}}\right) \right] \bigg |_{t^{\text {post}}_{k}}. \end{aligned}$$
(34)

We proceed by deriving the relationship between the adjoint variables before and after each transition. Since the computation of the adjoint variables happens in reverse time in practice, we provide \(\lambda ^-\) in terms of \(\lambda ^+\).

Consider a spike caused by the nth neuron, with all other neurons \(m\ne n\) remaining silent. We first derive the relationships between \(\frac{\partial V^+}{\partial w_{ji}}\), \(\frac{\partial V^-}{\partial w_{ji}}\) and \(\frac{\partial I^+}{\partial w_{ji}}\), \(\frac{\partial I^-}{\partial w_{ji}}\).

Membrane potential transition

By considering the relations between \(V^+\), \(V^-\) and \(\dot{V}^+\), \(\dot{V}^-\), we can derive the relation between \(\frac{\partial V^+}{\partial w_{ji}}\) and \(\frac{\partial V^-}{\partial w_{ji}}\) at each spike. Each spike at \(t^{\text {post}}_{}\) is triggered by a neuron’s membrane potential crossing the threshold. We therefore have, at \(t^{\text {post}}_{}\),

$$\begin{aligned} (V^-)_n - \vartheta = 0. \end{aligned}$$
(35)

This relation defines \(t^{\text {post}}_{}\) as a differentiable function of \(w_{ji}\) via the implicit function theorem (illustrated in Fig. 5, see also65), under the condition that \((\dot{V}^-)_n\ne 0\). Differentiation of this relation yields

$$\begin{aligned} \left( \frac{\partial V^-}{\partial w_{ji}}\right) _n + (\dot{V}^-)_n\frac{{\mathrm {d}}t^{\text {post}}_{}}{{\mathrm {d}}w_{ji}} = 0. \end{aligned}$$
(36)

Since we only allow transitions for \((\dot{V}^-)_n\ne 0\), we have

$$\begin{aligned} \frac{{\mathrm {d}}t^{\text {post}}_{}}{{\mathrm {d}}w_{ji}} = -\frac{1}{(\dot{V}^-)_n} \left( \frac{\partial V^-}{\partial w_{ji}}\right) _n. \end{aligned}$$
(37)

Note that corresponding relations were previously used to derive gradient-based learning rules for spiking neuron models20,21,22,26,66; in contrast to the suggestion in20, Eq. (37) is not an approximation but rather an exact relation at all non-critical parameters and invalid at all critical parameters.
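Equation (37) can be checked on a single neuron that receives one input spike of weight w at \(t=0\), for which \(V(t)\) is available in closed form. The Python sketch below is an illustration under assumed parameter values (time constants, threshold, weight) and with hypothetical helper names; it locates the threshold crossing by bisection and compares Eq. (37) with a finite difference of the spike time.

```python
import math

TAU_MEM, TAU_SYN, THETA = 20e-3, 5e-3, 1.0  # assumed constants (s, s, threshold)


def v(t, w):
    """Membrane potential for V(0) = 0 and I(0) = w (one input spike at t = 0)."""
    c = w * TAU_SYN / (TAU_SYN - TAU_MEM)
    return c * (math.exp(-t / TAU_SYN) - math.exp(-t / TAU_MEM))


def v_dot(t, w):
    """Time derivative from the free dynamics, tau_mem * dV/dt = -V + I."""
    return (-v(t, w) + w * math.exp(-t / TAU_SYN)) / TAU_MEM


def spike_time(w):
    """First threshold crossing, or None if the neuron stays below threshold."""
    t_peak = math.log(TAU_MEM / TAU_SYN) * TAU_MEM * TAU_SYN / (TAU_MEM - TAU_SYN)
    if v(t_peak, w) <= THETA:
        return None
    lo, hi = 0.0, t_peak  # V(lo) < THETA < V(hi)
    for _ in range(200):
        mid = 0.5 * (lo + hi)
        if v(mid, w) < THETA:
            lo = mid
        else:
            hi = mid
    return 0.5 * (lo + hi)


def spike_time_derivative(w):
    """Eq. (37): dt_post/dw = -(dV/dw) / (dV/dt) at the spike time."""
    t = spike_time(w)
    dv_dw = v(t, w) / w  # V is linear in w for this input
    return -dv_dw / v_dot(t, w)


if __name__ == "__main__":
    w, h = 10.0, 1e-4
    finite_diff = (spike_time(w + h) - spike_time(w - h)) / (2.0 * h)
    print(spike_time_derivative(w), finite_diff)
```

As w is lowered towards the value at which the peak of \(V\) just touches the threshold, `v_dot` at the crossing approaches zero and the derivative grows without bound, which is the critical point discussed in the text.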

Figure 5

In this sketch, the relation \(v(t,w)-\vartheta =0\) defines an implicit function (black line along which \({\mathrm {d}}v=0\)). The critical point where the gradient diverges is shown in red.

Because the spiking neuron’s membrane potential is reset to zero, we have

$$\begin{aligned} (V^+)_n = 0. \end{aligned}$$
(38)

This implies by differentiation

$$\begin{aligned} \left( \frac{\partial V^+}{\partial w_{ji}}\right) _n + (\dot{V}^+)_n\frac{{\mathrm {d}}t^{\text {post}}_{}}{{\mathrm {d}}w_{ji}}&= 0. \end{aligned}$$
(39)

Using Eq. (37), this allows us to relate the partial derivative after the spike to the partial derivative before the spike,

$$\begin{aligned} \left( \frac{\partial V^+}{\partial w_{ji}}\right) _n&= \frac{(\dot{V}^+)_n}{(\dot{V}^-)_n} \left( \frac{\partial V^-}{\partial w_{ji}}\right) _n. \end{aligned}$$
(40)

Since we have \((V^+)_m = (V^-)_m\) for all other, non-spiking neurons \(m\ne n\), it holds that

$$\begin{aligned} \left( \frac{\partial V^+}{\partial w_{ji}}\right) _m + (\dot{V}^+)_m \frac{{\mathrm {d}}t^{\text {post}}_{}}{{\mathrm {d}}w_{ji}} = \left( \frac{\partial V^-}{\partial w_{ji}}\right) _m + (\dot{V}^-)_m \frac{{\mathrm {d}}t^{\text {post}}_{}}{{\mathrm {d}}w_{ji}}. \end{aligned}$$
(41)

Because the spiking neuron n causes the synaptic current of all neurons \(m\ne n\) to jump by \(w_{mn}\), we have

$$\begin{aligned} \tau _\text {mem}(\dot{V}^+)_m = \tau _\text {mem}(\dot{V}^-)_m + w_{mn} \end{aligned}$$
(42)

and therefore get with Eq. (36)

$$\begin{aligned} \left( \frac{\partial V^+}{\partial w_{ji}}\right) _m&= \left( \frac{\partial V^-}{\partial w_{ji}}\right) _m - \tau _\text {mem}^{-1}w_{mn} \frac{{\mathrm {d}}t^{\text {post}}_{}}{{\mathrm {d}}w_{ji}} \end{aligned}$$
(43)
$$\begin{aligned}&= \left( \frac{\partial V^-}{\partial w_{ji}}\right) _m + \frac{1}{\tau _\text {mem}(\dot{V}^-)_n}w_{mn}\left( \frac{\partial V^-}{\partial w_{ji}}\right) _n. \end{aligned}$$
(44)

Synaptic current transition

The spiking neuron n causes the synaptic current of all neurons \(m\ne n\) to jump by the corresponding weight \(w_{mn}\). We therefore have

$$\begin{aligned} (I^+)_m = (I^-)_m + w_{mn}. \end{aligned}$$
(45)

By differentiation, this relation implies the consistency equations for the partial derivatives \(\frac{\partial I}{\partial w_{ji}}\) with respect to the considered weight \(w_{ji}\),

$$\begin{aligned} \left( \frac{\partial I^+}{\partial w_{ji}}\right) _m + (\dot{I}^+)_m\frac{{\mathrm {d}}t^{\text {post}}_{}}{{\mathrm {d}}w_{ji}} = \left( \frac{\partial I^-}{\partial w_{ji}}\right) _m + (\dot{I}^-)_m\frac{{\mathrm {d}}t^{\text {post}}_{}}{{\mathrm {d}}w_{ji}} + \delta _{in} \delta _{jm}, \end{aligned}$$
(46)

where \(\delta _{ji}\) is the Kronecker delta. Because

$$\begin{aligned} \tau _\text {syn}(\dot{I}^+)_m = \tau _\text {syn}(\dot{I}^-)_m - w_{mn}, \end{aligned}$$
(47)

we get with Eq. (36)

$$\begin{aligned} \left( \frac{\partial I^+}{\partial w_{ji}}\right) _m&= \left( \frac{\partial I^-}{\partial w_{ji}}\right) _m + \tau _\text {syn}^{-1}w_{mn}\frac{{\mathrm {d}}t^{\text {post}}_{}}{{\mathrm {d}}w_{ji}}+ \delta _{in} \delta _{jm} \end{aligned}$$
(48)
$$\begin{aligned}&= \left( \frac{\partial I^-}{\partial w_{ji}}\right) _m - \frac{1}{\tau _\text {syn}(\dot{V}^-)_n} w_{mn}\left( \frac{\partial V^-}{\partial w_{ji}}\right) _n+ \delta _{in} \delta _{jm}. \end{aligned}$$
(49)

With \((I^+)_n = (I^-)_n\) and \((\dot{I}^+)_n = (\dot{I}^-)_n\), we have

$$\begin{aligned} \left( \frac{\partial I^+}{\partial w_{ji}}\right) _n = \left( \frac{\partial I^-}{\partial w_{ji}}\right) _n. \end{aligned}$$
(50)

Using the relations of the partial derivatives from Eqs. (37), (40), (44), (49) and (50) in the transition equation Eq. (34), we now derive relations between the adjoint variables. Collecting terms in the partial derivatives and writing the index of the spiking neuron for the kth spike as n(k), we have

$$\begin{aligned} \xi _k&= \bigg [\sum _{m\ne n(k)} \bigg [\tau _\text {mem}(\lambda _V^- - \lambda _V^+)_m\left( \frac{\partial V^-}{\partial w_{ji}}\right) _m + \tau _\text {syn}(\lambda _I^- - \lambda _I^+)_m\left( \frac{\partial I^-}{\partial w_{ji}}\right) _m -\tau _\text {syn}\delta _{in(k)}\delta _{jm}(\lambda _I^+)_m \bigg ]\nonumber \\&\quad + \left( \frac{\partial V^-}{\partial w_{ji}}\right) _{n(k)}\left[ \tau _\text {mem}\left( \lambda _V^- - \frac{(\dot{V}^+)_{n(k)}}{(\dot{V}^-)_{n(k)}}\lambda _V^+\right) _{n(k)} \right. \nonumber \\&\quad \left. +\frac{1}{ (\dot{V}^-)_{n(k)}}\left( \sum _{m\ne n(k)}w_{mn(k)}(\lambda _I^+-\lambda _V^+)_m -\frac{\partial l_{\mathrm {p}}}{\partial t^{\text {post}}_{k}} + l_V^+ - l_V^-\right) \right] \nonumber \\&\quad + \tau _\text {syn}(\lambda _I^- - \lambda _I^+)_{n(k)}\left( \frac{\partial I^-}{\partial w_{ji}}\right) _{n(k)}\bigg ]\bigg |_{t^{\text {post}}_{k}}. \end{aligned}$$
(51)

This form dictates the jumps of the adjoint variables for the spiking neuron n and all other, silent neurons m,

$$\begin{aligned} (\lambda _V^-)_n&= \frac{(\dot{V}^+)_n}{(\dot{V}^-)_n}(\lambda _V^+)_n + \frac{1}{\tau _\text {mem}(\dot{V}^-)_n}\left[ \sum _{m\ne n} w_{mn}(\lambda _V^+ - \lambda _I^+)_m +\frac{\partial l_{\mathrm {p}}}{\partial t^{\text {post}}_{k}}+l_V^--l_V^+\right] , \end{aligned}$$
(52a)
$$\begin{aligned} (\lambda _V^-)_m&= (\lambda _V^+)_m, \end{aligned}$$
(52b)
$$\begin{aligned} \lambda _I^-&= \lambda _I^+. \end{aligned}$$
(52c)

With these jumps, the gradient reduces to

$$\begin{aligned} \frac{{\mathrm {d}}\mathcal {L}}{{\mathrm {d}}w_{ji}}&=-\tau _\text {syn}\sum _{k=1}^{N_{\text {post}}} \delta _{in(k)} (\lambda _I)_j \end{aligned}$$
(53)
$$\begin{aligned}&=-\tau _\text {syn}\sum _{\text {spikes from }i}(\lambda _I)_j. \end{aligned}$$
(54)

Summary

The free adjoint dynamics between spikes are given by Eq. (32) while spikes cause jumps given by Eq. (52). The gradient for a given weight samples the post-synaptic neuron’s \(\lambda _I\) when spikes are transmitted across the corresponding synapse [Eq. (53)]. Since we can identify, with \((\dot{V}^+)_n -(\dot{V}^-)_n=\tau _\text {mem}^{-1}\vartheta \),

$$\begin{aligned} \frac{(\dot{V}^+)_n}{(\dot{V}^-)_n} = \frac{(\dot{V}^+)_n -(\dot{V}^-)_n}{(\dot{V}^-)_n} + 1 = \frac{\vartheta }{\tau _\text {mem}(\dot{V}^-)_n} + 1 \end{aligned}$$
(55)

the derived solution is equivalent to Eq. (2) and Table 2.
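The interplay of the jump in Eq. (52a) and the free adjoint dynamics in Eq. (32) can be illustrated on the smallest possible case: one neuron, one input spike of weight w at \(t=0\), one output spike, and the loss \(l_{\mathrm {p}}=t^{\text {post}}\) with \(l_V=0\). The Python sketch below is an illustration only, not the reference implementation. It assumes \(\tau _\text {mem}=2\tau _\text {syn}\) so that the spike time has a closed form, uses explicit Euler steps in reverse time, and introduces hypothetical function names; the result is compared with a finite difference of the spike time.

```python
import math

TAU_SYN = 5e-3
TAU_MEM = 2.0 * TAU_SYN  # assumption: this ratio gives a closed-form spike time
THETA = 1.0              # threshold (assumed value)


def spike_time(w):
    """First threshold crossing for V(0) = 0, I(0) = w; requires w > 4 * THETA."""
    # With u = exp(-t / TAU_MEM), V(t) = w * (u - u**2); the first crossing
    # corresponds to the larger root of u**2 - u + THETA / w = 0.
    u = 0.5 * (1.0 + math.sqrt(1.0 - 4.0 * THETA / w))
    return -TAU_MEM * math.log(u)


def eventprop_gradient(w, steps=20000):
    """d t_post / d w obtained from the adjoint jump and reverse-time dynamics."""
    t_post = spike_time(w)
    i_minus = w * math.exp(-t_post / TAU_SYN)
    v_dot_minus = (-THETA + i_minus) / TAU_MEM
    # Eq. (52a) with lambda^+ = 0, l_V = 0 and d l_p / d t_post = 1.
    lam_v = 1.0 / (TAU_MEM * v_dot_minus)
    lam_i = 0.0  # Eq. (52c): lambda_I does not jump
    ds = t_post / steps
    for _ in range(steps):
        # Eq. (32) in reverse time; both updates use the previous values.
        lam_v, lam_i = (lam_v - ds * lam_v / TAU_MEM,
                        lam_i + ds * (lam_v - lam_i) / TAU_SYN)
    # The gradient samples lambda_I where the input spike arrives (t = 0).
    return -TAU_SYN * lam_i


if __name__ == "__main__":
    w, h = 6.0, 1e-5
    finite_diff = (spike_time(w + h) - spike_time(w - h)) / (2.0 * h)
    print(eventprop_gradient(w), finite_diff)
```

Because the adjoint variables vanish after the last spike, the backward pass here only needs to cover the interval between the output spike and the input spike.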

Fixed Input Spikes

If a given neuron i is subjected to a fixed pre-synaptic spike train across a synapse with weight \(w_\text {input}\), the transition times are fixed and the adjoint variables do not experience jumps. The gradient simply samples the neuron’s \(\lambda _I\) at the times of spike arrival,

$$\begin{aligned} \frac{{\mathrm {d}}\mathcal {L}}{{\mathrm {d}}w_\text {input}} = -\tau _\text {syn}\sum _{\text {input spikes}}(\lambda _I)_i. \end{aligned}$$
(56)
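Equation (56) can be illustrated with a single neuron driven by a fixed input spike train and a voltage-dependent loss. The Python sketch below rests on several assumptions that are not taken from the original work: the neuron stays below threshold (so no adjoint jumps occur), the loss is \(l_V=\frac{1}{2}(V-V_\text {target})^2\), both passes use explicit Euler steps on a fixed grid, and the spike steps, weight and target value are hypothetical.

```python
TAU_MEM, TAU_SYN = 20e-3, 5e-3  # assumed time constants in seconds


def forward(w, spike_steps, steps, dt):
    """Euler-integrate V and I; each input spike adds w to I. Returns V per step."""
    spikes = set(spike_steps)
    v, i = 0.0, 0.0
    vs = []
    for k in range(steps):
        if k in spikes:
            i += w
        vs.append(v)
        v, i = v + dt * (-v + i) / TAU_MEM, i - dt * i / TAU_SYN
    return vs


def loss(w, spike_steps, steps, dt, v_target):
    """Discretized integral of l_V = 0.5 * (V - v_target)**2 over the trajectory."""
    vs = forward(w, spike_steps, steps, dt)
    return dt * sum(0.5 * (v - v_target) ** 2 for v in vs)


def eventprop_gradient(w, spike_steps, steps, dt, v_target):
    """Backward pass: Eq. (32) in reverse time, gradient sampled as in Eq. (56)."""
    vs = forward(w, spike_steps, steps, dt)
    spikes = set(spike_steps)
    lam_v, lam_i = 0.0, 0.0  # lambda_V(T) = lambda_I(T) = 0
    grad = 0.0
    for k in range(steps - 1, -1, -1):
        dl_dv = vs[k] - v_target
        # Reverse-time Euler step; both updates use the previous values.
        lam_v, lam_i = (lam_v + dt * (-lam_v - dl_dv) / TAU_MEM,
                        lam_i + dt * (-lam_i + lam_v) / TAU_SYN)
        if k in spikes:
            grad += -TAU_SYN * lam_i  # sample lambda_I at spike arrival
    return grad


if __name__ == "__main__":
    args = ([500, 1500, 3000], 10000, 1e-5, 0.2)  # spike steps, steps, dt, v_target
    w, h = 0.5, 1e-4
    finite_diff = (loss(w + h, *args) - loss(w - h, *args)) / (2.0 * h)
    print(eventprop_gradient(w, *args), finite_diff)
```

The backward pass stores nothing per synapse: the same \(\lambda _I\) trace would serve every input synapse of the neuron, each sampling it at its own spike arrival times.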

Coincident spikes

The derivation above assumes that only a single neuron of the recurrent network spikes at a given \(t^{\text {post}}_{k}\). In general, coincident spikes may occur. If neurons a and b spike at the same time and the times of their respective threshold crossings vary independently as a function of \(w_{ji}\), the derivation above still holds, with both neurons' \(\lambda _V\) experiencing a jump as in Eq. (52a).