# 2013/SKIMonNengo: testskim.py

File testskim.py, 2.5 KB (added by drasmuss, 5 years ago) |
---|

Line | |
---|---|

1 | import sys |

2 | sys.path.append("location of current directory") |

3 | |

4 | import nef, skimnetwork, math, random |

5 | from ca.nengo.util import MU |

6 | |

7 | |

def create_noise(t, D, dt=0.001, sd=0.9):
    """Create a random spike pattern, used as the template we want to detect.

    The result is a list of ``timesteps`` rows, each a list of ``D`` floats.
    Each of the ``D`` dimensions carries exactly one spike (value 1.0) at a
    uniformly random timestep; every other entry is 0.0.

    Parameters
    ----------
    t : float
        Length of the pattern in seconds.
    D : int
        Number of dimensions (columns) in the pattern.
    dt : float
        Timestep in seconds.
    sd : float
        Unused; kept only so existing callers that pass it keep working.

    Returns
    -------
    list of list of float
        A ``timesteps x D`` matrix with ``timesteps = round(t / dt)``.

    Raises
    ------
    ValueError
        If ``t`` is shorter than one timestep, leaving no room for a spike.
    """
    # round() rather than int(): int() truncates, so a quotient such as
    # 0.3 / 0.1 == 2.9999999999999996 would silently drop a timestep.
    timesteps = int(round(float(t) / dt))
    if timesteps < 1:
        raise ValueError("pattern length t=%r is shorter than one timestep dt=%r"
                         % (t, dt))

    p = [[0.0] * D for _ in range(timesteps)]
    for i in range(D):
        # One spike per dimension, at an independently chosen timestep.
        p[random.randrange(timesteps)][i] = 1.0
    return p

18 | |

class PatternedInput:
    """Random input signal with one fixed spike pattern inserted at given times.

    Instances are callable: ``obj(t)`` returns the input row for time ``t``.

    Attributes
    ----------
    pattern : list of list of float
        Template produced by ``create_noise`` (``dimensions`` columns).
    data : list of list of float
        Full input signal, ``round(T / dt)`` rows x ``dimensions`` columns.
        It is Gaussian noise (mean 0.0, sd 0.01), with copies of ``pattern``
        overwriting the noise wherever a pattern was inserted.
    correct : list of list of float
        Target output, one single-element row per timestep. The value is 1.0
        for ``signal_length`` timesteps right after each inserted pattern and
        0.0 elsewhere.
    dt : float
        Timestep in seconds.
    """

    def __init__(self, dimensions, pattern_times, pattern_length=0.01, T=1,
                 dt=0.001, signal_length=30):
        """Build the patterned signal and its target output.

        Parameters
        ----------
        dimensions : int
            Number of input dimensions.
        pattern_times : iterable of float
            Event times in seconds. For each time ``t`` the pattern ends one
            timestep before ``int(t / dt)``, and the target window starts
            immediately after the pattern. Events whose pattern or target
            window would not fit inside the signal are skipped.
        pattern_length : float
            Length of the pattern in seconds.
        T : float
            Total signal length in seconds.
        dt : float
            Timestep in seconds.
        signal_length : int
            Number of timesteps the target stays at 1.0 after each pattern.
            The default of 30 matches the previously hard-coded value.
        """
        # Created first so the sequence of random draws is unchanged
        # (pattern, then noise).
        self.pattern = create_noise(pattern_length, dimensions, dt=dt)
        self.dt = dt
        # round() so float error in T / dt cannot drop the final timestep.
        timesteps = int(round(float(T) / dt))
        self.data = [[random.gauss(0.0, 0.01) for _ in range(dimensions)]
                     for _ in range(timesteps)]
        self.correct = [[0.0] for _ in range(timesteps)]

        plen = len(self.pattern)
        for t in pattern_times:
            # The pattern occupies [start, target_start); the target window
            # occupies [target_start, target_end).
            start = int(t / dt) - plen - 1
            target_start = start + plen
            target_end = target_start + signal_length
            # Use <= (not <): slice ends are exclusive, so a window ending
            # exactly at the last timestep still fits.
            if start >= 0 and target_end <= timesteps:
                self.correct[target_start:target_end] = [
                    [1.0] for _ in range(signal_length)]
                # Copy the rows so that data never aliases self.pattern or
                # other insertions of it.
                self.data[start:target_start] = [list(row)
                                                 for row in self.pattern]

    def __call__(self, t):
        """Return the input row for simulation time ``t`` (seconds).

        Time maps to a timestep by truncation and wraps around, so the signal
        repeats with a period of ``len(self.data) * dt`` seconds.
        """
        index = int(t / self.dt)
        return self.data[index % len(self.data)]

38 | |

39 | |

net = nef.Network("testSKIM")

N = 80  # number of dendrites per output neuron
D = 5  # input dimensions of each patterned signal (network gets 2*D in total)
T = 4  # length of the signal, in seconds
dt = 0.001  # timestep, in seconds
pattern_length = 0.1  # length of the patterns in the signal, in seconds
event_rate = 5  # number of patterns per second
threshold = 0.35  # threshold used for pattern detection

# Generate random inputs: two independent patterned signals, each with its own
# pattern and its own event times. Their transposed data/targets are then
# concatenated, giving one target row per signal.
n_events = int(event_rate * T)
event_times = [random.uniform(0, T) for _ in range(n_events)]
event_times2 = [random.uniform(0, T) for _ in range(n_events)]

# Pass dt explicitly so the inputs are sampled on the same timestep that the
# SKIM network is built with below.
pi = PatternedInput(D, event_times, T=T, pattern_length=pattern_length, dt=dt)
pi2 = PatternedInput(D, event_times2, T=T, pattern_length=pattern_length, dt=dt)

inputs = MU.transpose(pi.data) + MU.transpose(pi2.data)
outputs = MU.transpose(pi.correct) + MU.transpose(pi2.correct)

# Create the SKIM network and add it to the Nengo workspace.
net.add(skimnetwork.SKIMNetwork(N, inputs, outputs, dt=dt, threshold=threshold))
net.add_to_nengo()