## Calcium dynamics and STDP

In [31]:
%matplotlib inline
import matplotlib.pyplot as plt
from numpy import *
from IPython.html import widgets
from matplotlib import rcParams
import matplotlib.gridspec as gridspec

class timeAboveThreshold():
    '''
        Class to calculate the fraction of time alpha that the calcium trace spends
        above the depression (thetaD) and potentiation (thetaP) thresholds when
        pre- and postsynaptic spikes are paired periodically.

        Every spike adds a fixed amplitude (Cpre or Cpost) to the calcium trace,
        which then decays exponentially with time constant tauCa.
    '''

    def __init__(self, tauCa, Cpre, Cpost, thetaD, thetaP, nonlinear=1.):
        self.tauCa = tauCa    # calcium decay time constant (same time unit as deltaT)
        self.Cpre = Cpre      # calcium jump caused by a presynaptic spike
        self.Cpost = Cpost    # calcium jump caused by a postsynaptic spike
        self.thetaD = thetaD  # depression threshold
        self.thetaP = thetaP  # potentiation threshold
        # eta is derived from the nonlinearity factor and the amplitudes;
        # it is not used by spikePairFrequency.
        self.nonlinear = nonlinear
        self.eta = (self.nonlinear*(self.Cpost + self.Cpre) - self.Cpost)/self.Cpre - 1.

    def spikePairFrequency(self, deltaT, frequency=1):
        '''
            Return (alphaD, alphaP): the fractions of one pairing interval that the
            steady-state calcium trace spends above thetaD and thetaP.

            deltaT    : postsynaptic spike time minus presynaptic spike time.
                        deltaT < 0 is a post-pre pairing, deltaT >= 0 a pre-post one.
            frequency : pairing frequency; pairs repeat every 1/frequency.

            Raises ValueError if the calcium levels cannot be compared with a
            threshold (e.g. they evaluate to NaN).
        '''
        interval = 1./frequency

        # In a periodic protocol deltaT and deltaT +/- interval describe the same
        # spike pattern, so fold larger lags back into one interval. fmod keeps the
        # sign of deltaT, so a positive lag stays a pre-post pairing.
        if fabs(deltaT) > interval:
            deltaT = fmod(deltaT, interval)
        lag = fabs(deltaT)

        # The first spike of each pair is the postsynaptic one for post-pre
        # pairings and the presynaptic one otherwise.
        if deltaT < 0.:
            label = 'post-pre'
            first, second = self.Cpost, self.Cpre
        else:
            label = 'pre-post'
            first, second = self.Cpre, self.Cpost

        # Steady-state calcium just before each spike: the geometric sum of all
        # earlier spikes' decayed contributions. Written with exp(-interval/tauCa)
        # so it stays finite for long intervals, where exp(+interval/tauCa) would
        # overflow and produce inf/inf = nan.
        decay = exp(-interval/self.tauCa)
        norm = -expm1(-interval/self.tauCa)   # 1 - decay, accurate when decay ~ 1
        beforeFirst = (first*decay + second*exp((lag - interval)/self.tauCa))/norm
        beforeSecond = (second*decay + first*exp(-lag/self.tauCa))/norm
        # ...and just after each spike.
        afterFirst = beforeFirst + first
        afterSecond = beforeSecond + second

        levels = (beforeFirst, beforeSecond, afterFirst, afterSecond)
        timeAboveD = self._timeAbove(self.thetaD, levels, lag, interval, label)
        timeAboveP = self._timeAbove(self.thetaP, levels, lag, interval, label)
        return (timeAboveD/interval, timeAboveP/interval)

    def _timeAbove(self, Ct, levels, lag, interval, label):
        '''
            Time per interval that the periodic calcium trace spends above Ct.

            levels = (beforeFirst, beforeSecond, afterFirst, afterSecond): the trace
            just before and just after the first and second spike of a pair. The
            trace is lowest just before a spike and decays between spikes, so
            falling from a peak above Ct down to Ct takes tauCa*log(peak/Ct).
        '''
        beforeFirst, beforeSecond, afterFirst, afterSecond = levels
        tau = self.tauCa
        if beforeFirst <= Ct and beforeSecond > Ct:
            # above from the first spike through the second, until the decay
            # after the second spike crosses Ct
            return tau*log(afterSecond/Ct) + lag
        if beforeFirst > Ct and beforeSecond > Ct:
            # never drops below Ct
            return interval
        if beforeFirst > Ct and beforeSecond <= Ct:
            # above from the second spike through the next first spike, until the
            # decay after the first spike crosses Ct
            return tau*log(afterFirst/Ct) + interval - lag
        if beforeFirst <= Ct and beforeSecond <= Ct:
            # below Ct before both spikes: only the decay after each
            # suprathreshold peak counts
            timeAbove = 0.
            if afterFirst > Ct:
                timeAbove += tau*log(afterFirst/Ct)
            if afterSecond > Ct:
                timeAbove += tau*log(afterSecond/Ct)
            return timeAbove
        # only reachable when a level is NaN
        raise ValueError(
            '%s : Problem in spikePairFrequency! Cannot classify calcium levels '
            '%r, %r, %r, %r against threshold %r (lag=%r, interval=%r)'
            % (label, beforeFirst, beforeSecond, afterFirst, afterSecond, Ct, lag, interval))

def synapticChange(GammaD, GammaP):
    '''
        Synaptic efficacy after T_total of stimulation, starting from 0.5.

        GammaD, GammaP : effective depression / potentiation rates (scalars or
                         arrays of matching shape).

        The efficacy relaxes from 0.5 towards wbar = GammaP/(GammaD + GammaP) with
        time constant tau/(GammaD + GammaP); tau and T_total are read from module
        level. Where GammaD + GammaP == 0 nothing drives plasticity, so the result
        is 0.5 (the limit of the formula) instead of 0/0 = nan.
        Scalar inputs return a scalar, array inputs an array.
    '''
    GammaD = asarray(GammaD, dtype=float)
    GammaP = asarray(GammaP, dtype=float)
    rate = GammaD + GammaP
    active = rate != 0
    # substitute a harmless rate where inactive so the divisions never hit 0/0
    safeRate = where(active, rate, 1.)
    wbar = GammaP/safeRate
    tauEff = tau/safeRate
    sc = where(active, wbar - (wbar - 0.5)*exp(-T_total/tauEff), 0.5)
    return sc[()] if sc.ndim == 0 else sc
    
# calcium dynamics parameters
case = 'shouval'
tauca = 0.02   # calcium time constant in sec
c0 = 0         # initial calcium concentration
thetaD = 1.    # depression threshold
thetaP = 1.5   # potentiation threshold
T_total = 60.  # stimulation duration entering synapticChange
tau = 2.       # time constant entering tauEff in synapticChange

if case == 'shouval':
    Cpre  = 1.    # presynaptically induced amplitude
    Cpost = 0.9   # postsynaptically induced amplitude
    gammaP = 6.
    gammaD = 1.
    D = 0.        # offset subtracted from the spike timing before evaluating tAT
elif case == 'graupner':
    Cpre = 1.
    Cpost = 2.
    gammaD = 1.
    # ratio of the times a single Cpost transient spends above thetaD and thetaP
    gammaP = log(Cpost/thetaD)/log(Cpost/thetaP)
    D = 0.
else:
    raise ValueError("unknown case %r; expected 'shouval' or 'graupner'" % (case,))


tAT = timeAboveThreshold(tauca, Cpre, Cpost, thetaD, thetaP)

# Spike-timing differences in sec on a 1 ms grid. round() guards against float
# division landing just below an integer, which int() would truncate and so
# silently drop a sample.
DT = linspace(-0.15, 0.15, int(round(0.3/0.001)) + 1)
tATArray = zeros((len(DT), 2))
for i, dT in enumerate(DT):
    tATArray[i] = tAT.spikePairFrequency(dT - D)

# Time axis for the calcium trace. Without round(), (tEnd-tStart)/dt evaluates
# to 2999.9999999999995 and int() yields 3000 samples instead of 3001.
tStart = -0.15
tEnd   =  0.15
dt     = 0.0001
t = linspace(tStart, tEnd, int(round((tEnd - tStart)/dt)) + 1)

def pltCaTimeAboveThreshold(DeltaT):
    '''
        Draw the calcium trace, the time above threshold and the synaptic change
        for one spike-timing difference.

        DeltaT : postsynaptic minus presynaptic spike time in ms (slider value).

        Top panel: calcium trace for a pre spike at 0 ms and a post spike at DeltaT,
        with thetaD / thetaP as dashed lines. Bottom left: tATArray curves scaled by
        1000 with the current DeltaT marked. Bottom right: synapticChange divided by
        0.5 with the current DeltaT marked. Reads module-level state (tAT, t, DT,
        tATArray, tauca, Cpre, Cpost, thetaD, thetaP, D, gammaD, gammaP) and updates
        matplotlib rcParams.
    '''
    def _despine(axis):
        # hide the top/right frame lines and push the remaining spines outward
        for side in ('top', 'right'):
            axis.spines[side].set_visible(False)
        for side in ('bottom', 'left'):
            axis.spines[side].set_position(('outward', 10))
        axis.yaxis.set_ticks_position('left')
        axis.xaxis.set_ticks_position('bottom')

    shiftSec = DeltaT/1000.
    # (alphaD, alphaP) at the selected timing, using the same offset D as the curves
    current = tAT.spikePairFrequency(shiftSec - D)

    # superpose the decaying pre (at 0) and post (at shiftSec) calcium transients
    calcium = zeros(len(t))
    afterPre = t > 0
    calcium[afterPre] += Cpre*exp(-t[afterPre]/tauca)
    afterPost = t > shiftSec
    calcium[afterPost] += Cpost*exp(-(t[afterPost] - shiftSec)/tauca)

    rcParams.update({'font.size': 16, 'figure.figsize': [12, 8]})
    plt.figure()
    outer = gridspec.GridSpec(2, 1, wspace=0.3, hspace=0.3)

    peak = Cpre + Cpost
    traceAx = plt.subplot(outer[0])
    traceAx.plot(t*1000., calcium, lw=2, clip_on=False)
    traceAx.vlines(DeltaT, peak + 0.1, peak + 0.3, color='blue')
    traceAx.vlines(0, peak + 0.4, peak + 0.6, color='green')
    for level, colour in ((thetaD, 'C2'), (thetaP, 'C3')):
        traceAx.axhline(y=level, ls='--', c=colour)
    traceAx.set_xlabel('time (ms)')
    traceAx.set_ylabel('calcium')
    traceAx.set_ylim(0, peak + 0.6)
    _despine(traceAx)

    labelStyle = dict(annotation_clip=False, xytext=None, textcoords='data',
                      fontsize=12, arrowprops=None)
    traceAx.annotate(r'pre-spike', xy=(2, peak + 0.45), color='green', **labelStyle)
    traceAx.annotate(r'post-spike', xy=(DeltaT + 2, peak + 0.15), color='blue', **labelStyle)

    lower = gridspec.GridSpecFromSubplotSpec(1, 2, subplot_spec=outer[1], wspace=0.4)

    alphaAx = plt.subplot(lower[0])
    alphaAx.set_xlabel(r'$\Delta t$ (ms)')
    alphaAx.set_ylabel('time above threshold (ms)')
    for column, colour in ((0, 'C2'), (1, 'C3')):
        alphaAx.plot(DT*1000., tATArray[:, column]*1000., lw=2, c=colour)
    for value, colour in zip(current, ('C2', 'C3')):
        alphaAx.plot(DeltaT, value*1000., '^', ms=8, c=colour)
    alphaAx.set_xlim(-160, 160)
    _despine(alphaAx)

    changeAx = plt.subplot(lower[1])
    changeAx.set_xlabel(r'$\Delta t$ (ms)')
    changeAx.set_ylabel('synaptic change')
    curve = synapticChange(gammaD*tATArray[:, 0], gammaP*tATArray[:, 1])
    marker = synapticChange(current[0]*gammaD, current[1]*gammaP)
    changeAx.plot(DT*1000., curve/0.5, lw=2, c='black')
    changeAx.plot(DeltaT, marker/0.5, '^', ms=8, c='C0')
    changeAx.axhline(y=1, ls='--', c='0.5')
    _despine(changeAx)

    plt.show()


# Interactive control: DeltaT (ms) ranges from -150 to 150 in steps of 5.
widgets.interact(
    pltCaTimeAboveThreshold,
    DeltaT=(-150, 150, 5),
)

interactive(children=(IntSlider(value=0, description='DeltaT', max=150, min=-150, step=5), Output()), _dom_cla…

<function __main__.pltCaTimeAboveThreshold(DeltaT)>