UESMANN CPP  1.0
Reference implementation of UESMANN
/home/travis/build/jimfinnis/uesmanncpp/net.hpp
Go to the documentation of this file.
1 
8 #ifndef __NET_HPP
9 #define __NET_HPP
10 
11 #include <math.h>
12 
13 #include "netType.hpp"
14 #include "data.hpp"
15 
/**
 * The logistic sigmoid activation function.
 * @param x input value
 * @return 1/(1+e^-x), a value in the open interval (0,1)
 */
inline double sigmoid(double x){
    double expNeg = exp(-x);
    return 1.0/(1.0+expNeg);
}
23 
28 inline double sigmoidDiff(double x){
29  double s = sigmoid(x);
30  return (1.0-s)*s;
31 }
32 
39 class Net {
40  friend class OutputBlendingNet;
41  friend class HInputNet;
42 public:
43 
47  virtual ~Net() {}
48 
50  drand48_data rd;
51 
    /**
     * Seed this network's private PRNG state (rd), which is used for
     * weight initialisation and shuffling. Uses the reentrant glibc
     * generator so each Net carries independent random state.
     * @param seed the seed value
     */
    void setSeed(long seed){
        srand48_r(seed,&rd);
    }
60 
65  virtual int getLayerSize(int n) const =0;
66 
70  virtual int getLayerCount() const =0;
71 
    /**
     * Get the number of inputs: the size of layer 0.
     * @return the input count
     */
    int getInputCount() const {
        return getLayerSize(0);
    }
78 
82  int getOutputCount() const {
83  return getLayerSize(getLayerCount()-1);
84  }
85 
86 
87 
93  virtual void setInputs(double *d) = 0;
94 
100  virtual double *getOutputs() const = 0;
101 
    /**
     * Run the network on some data: set the inputs, propagate with
     * update(), and return the resulting outputs.
     * @param in pointer to an array of getInputCount() doubles
     * @return pointer to the network's outputs
     *   (NOTE(review): buffer lifetime is presumably managed by the
     *   concrete subclass and valid until the next run — confirm)
     */
    double *run(double *in) {
        setInputs(in);
        update();
        return getOutputs();
    }
112 
117  virtual void setH(double h)=0;
118 
122  virtual double getH() const =0;
123 
142  double test(ExampleSet& examples,int start=0,int num=-1){
143  double mseSum = 0;
144  // have to do this here, too, although runExamples does it, so we can
145  // get the denominator for the mse.
146  if(num<0)num=examples.getCount()-start;
147 
148  // for each example, run it and accumulate the sum of squared errors
149  // on all outputs
150 
151  for(int i=0;i<num;i++){
152  int idx = start+i;
153  setH(examples.getH(idx));
154  double *netout = run(examples.getInputs(idx));
155  double *exout = examples.getOutputs(idx);
156  for(int j=0;j<examples.getOutputCount();j++){
157  double d = netout[j]-exout[j];
158  mseSum += d*d;
159  }
160  }
161 
162  // we then divide by the number of examples and the output count.
163  return mseSum / (num * examples.getOutputCount());
164  }
165 
173  struct SGDParams {
174  friend class Net;
175 
182 
186  double eta;
187 
188 
192  int nSlices;
193 
198 
204 
210  SGDParams& crossValidationManual(int slices,int nperslice,int interval){
211  nSlices = slices;
212  nPerSlice = nperslice;
213  cvInterval = interval;
214  return *this;
215  }
216 
222 
225  shuffleMode = m;
226  return *this;
227  }
228 
229 
230 
237 
240  selectBestWithCV=v;
241  return *this;
242  }
243 
244 
249  bool cvShuffle;
250 
        /**
         * Fluent setter for cvShuffle: whether to reshuffle the whole
         * CV set once every slice has been visited.
         * @param v the new value (defaults to true)
         * @return a reference to this, for call chaining
         */
        SGDParams& setCVShuffle(bool v=true){
            cvShuffle=v;
            return *this;
        }
256 
261 
        /**
         * Fluent setter for initrange: initial weights/biases are drawn
         * from [-range,range], or pass -1 to use the default rule.
         * @param range the new range, or -1 for the default
         * @return a reference to this, for call chaining
         */
        SGDParams& setInitRange(double range=-1){
            initrange = range;
            return *this;
        }
267 
272  long seed;
273 
        /**
         * Fluent setter for the PRNG seed used by trainSGD() for weight
         * initialisation and shuffling.
         * @param v the seed value
         * @return a reference to this, for call chaining
         */
        SGDParams& setSeed(long v){
            seed = v;
            return *this;
        }
279 
284  double *bestNetBuffer;
285 
290 
291  private:
        /**
         * Set every parameter to its default given the learning rate and
         * iteration count; called by both constructors.
         * @param _eta   the learning rate
         * @param _iters number of iterations (single-example presentations)
         */
        void init(double _eta,int _iters){
            seed = 0L;
            eta = _eta;
            iterations = _iters;
            initrange = -1;            // -1 selects the default weight-init rule
            bestNetBuffer = NULL;      // allocated lazily during training if needed
            ownsBestNetBuffer = false;
            storeBestNet = false;
            nSlices=0;                 // no cross-validation by default
            nPerSlice=0;
            cvInterval=1;
            shuffleMode = ExampleSet::STRIDE;
            selectBestWithCV=false; // there might not be CV!
            cvShuffle = true; // do shuffle CV at the end of an epoch
        }
313  public:
        /**
         * Constructor which sets up defaults with no information about the
         * examples; cross-validation is not set up (use crossValidation()
         * or crossValidationManual() afterwards if required).
         * @param _eta   the learning rate
         * @param _iters total number of iterations (example presentations)
         */
        SGDParams(double _eta, int _iters) {
            init(_eta,_iters);
        }
326 
        /**
         * Constructor taking an example set: the iteration count is given
         * in epochs, i.e. _iters * examples.getCount() presentations.
         * @param _eta     the learning rate
         * @param examples the training set (used only for its count)
         * @param _iters   number of epochs
         */
        SGDParams(double _eta,const ExampleSet& examples,int _iters){
            init(_eta,examples.getCount()*_iters);
        }
335 
341  if(ownsBestNetBuffer)delete[] bestNetBuffer;
342  }
343 
357  double propCV,
358  int cvCount,
359  int cvSlices,
360  bool cvShuf=true
361  ){
362  cvShuffle = cvShuf;
363  // calculate the number of CV examples
364  int nCV = (int)round(propCV*examples.getCount());
365  if(nCV==0 || nCV>examples.getCount())
366  throw std::out_of_range("Bad cross-validation count");
367  if(cvSlices<=0)
368  throw std::out_of_range("Zero (or fewer) CV slices is a bad thing");
369  // calculate the number of examples per slice and check it's not zero.
370  // The resulting number of CV examples may not agree with nCV above due
371  // to the integer division
372  nPerSlice = nCV/cvSlices;
373  nSlices = cvSlices;
374  if(!nPerSlice)
375  throw std::logic_error("Too many slices");
376  // calculate the cvInterval
377  cvInterval = iterations/cvCount;
378  if(cvInterval<=0)
379  throw std::logic_error("Too many CV events");
380  // we want to pick the best network by CV rather than training error
381  selectBestWithCV=true;
382 
383  printf("Cross-validation: %d slices, %d items per slice, %d total\n",
384  nSlices,nPerSlice,nSlices*nPerSlice);
385  return *this;
386  }
387 
395  ownsBestNetBuffer = true;
396  storeBestNet = true;
397  return *this;
398  }
399  private:
404  bool ownsBestNetBuffer;
405  };
406 
407 
408 
428  double trainSGD(ExampleSet &examples,SGDParams& params){
429 
430  // set seed for PRNG
431  setSeed(params.seed);
432 
433  // separate out the training examples from the cross-validation examples
434  int nCV = params.nSlices*params.nPerSlice;
435  // it's an error if there are too many CV examples
436  if(nCV>=examples.getCount())
437  throw std::out_of_range("Too many cross-validation examples");
438 
439  if(!nCV && params.selectBestWithCV)
440  throw std::logic_error("cannot use CV to select best when no CV is done");
441 
442  // get the number of actual training examples
443  int nExamples = examples.getCount() - nCV;
444 
445  // initialise the network
446  initWeights(params.initrange);
447 
448  // initialise minimum error to rogue value
449  double minError = -1;
450 
451  // We don't shuffle before getting the cross-validation examples,
452  // because in some cases there's a kind of "fake" cv going on where the
453  // training portion and cv portion have to have similar (or identical)
454  // distributions. See the boolean test code for an example.
455  // examples.shuffle(&rd,params.shuffleMode);
456 
457  // build a temporary subset for the CV examples. This still needs to exist
458  // even if we're not using CV, so in that case we'll just
459  // use a dummy of one example.
460 
461  ExampleSet cvExamples(examples,nCV?examples.getCount()-nCV:0,nCV?nCV:1);
462 
463 
464  // setup a countdown for when we cross-validate
465  int cvCountdown = params.cvInterval;
466  // and which slice we are doing
467  int cvSlice = 0;
468 
469  // now actually do the training
470 
471  FILE *log = fopen("foo","w");
472  fprintf(log,"x,slice,y\n");
473  for(int i=0;i<params.iterations;i++){
474  // find the example number
475  int exampleIndex = i % nExamples;
476 
477  // at the start of each epoch, reshuffle. This will effectively do an extra shuffle
478  // as we've already done it once at the start, before splitting out the CV examples.
479 
480  if(exampleIndex == 0)
481  examples.shuffle(&rd,params.shuffleMode,nExamples);
482 
483  // train here, just one example, no batching.
484  double trainingError = trainBatch(examples,exampleIndex,1,params.eta);
485 
486  if(!params.selectBestWithCV){
487  // now test the error and keep the best net. This works differently
488  // if we're doing this by cross-validation or training error. Here
489  // we're using the training error.
490  if(minError < 0 || trainingError < minError){
491  if(params.storeBestNet){
492  if(!params.bestNetBuffer)
493  params.bestNetBuffer = new double[getDataSize()];
494  save(params.bestNetBuffer);
495  }
496  minError = trainingError;
497  }
498  }
499 
500  // is there cross-validation? If so, do it.
501 
502  if(nCV && !--cvCountdown){
503  cvCountdown = params.cvInterval; // reset
504 
505  // test the appropriate slice, from example cvSlice*nPerSlice, length nPerSlice,
506  // and get the MSE
507  double error = test(cvExamples,cvSlice*params.nPerSlice,
508  params.nPerSlice);
509  fprintf(log,"%d,%d,%f\n",i,cvSlice,error);
510 
511  // test this against the min error as was done above
512  if(params.selectBestWithCV){
513  if(minError < 0 || trainingError < minError){
514  if(params.storeBestNet){
515  if(!params.bestNetBuffer)
516  params.bestNetBuffer = new double[getDataSize()];
517  save(params.bestNetBuffer);
518  }
519  minError = trainingError;
520  }
521  }
522 
523  // increment the slice index
524  cvSlice = (cvSlice+1)%params.nSlices;
525  // if we are now on the first slice, shuffle the entire CV set
526  if(!cvSlice && params.cvShuffle)
527  cvExamples.shuffle(&rd,params.shuffleMode);
528  }
529  }
530 
531  fclose(log);
532 
533  // at the end, finalise the network to the best found if we can
534  if(params.bestNetBuffer)
535  load(params.bestNetBuffer);
536 
537  // test on either the entire CV set or the training set and return result
538  return test(nCV?cvExamples:examples);
539  }
540 
546  virtual int getDataSize() const = 0;
547 
553  virtual void save(double *buf) const = 0;
554 
561  virtual void load(double *buf) = 0;
562 
563 protected:
564 
565 
566 
573  virtual void update() = 0;
574 
581  type = tp;
582  setSeed(0);
583  }
584 
591  inline double drand(double mn,double mx){
592  double res;
593  drand48_r(&rd,&res);
594  return res*(mx-mn)+mn;
595  }
596 
602  virtual void initWeights(double initr) = 0;
603 
634  virtual double trainBatch(ExampleSet& ex,int start,int num,double eta) = 0;
635 
636 };
637 
638 
639 #endif /* __NET_HPP */
int getCount() const
get the number of examples
Definition: data.hpp:327
Contains integer enum for network types.
NetType
The different types of network - each has an associated integer for saving/loading file data...
Definition: netType.hpp:15
SGDParams & setInitRange(double range=-1)
fluent setter for initrange
Definition: net.hpp:263
int initrange
range of initial weights/biases [-n,n], or -1 for Bishop's rule.
Definition: net.hpp:260
virtual int getLayerCount() const =0
Get the number of layers.
double eta
Definition: net.hpp:186
SGDParams & crossValidationManual(int slices, int nperslice, int interval)
fluent setter for cross-validation parameters manually; consider using crossValidation instead ...
Definition: net.hpp:210
SGDParams & setSelectBestWithCV(bool v=true)
fluent setter for selectBestWithCV
Definition: net.hpp:239
A modulatory network architecture which uses two plain backprop networks, each of which is trained se...
Definition: obnet.hpp:18
Training parameters for trainSGD(). This structure holds the parameters for the trainSGD() method...
Definition: net.hpp:173
ShuffleMode
Shuffling mode for shuffle()
Definition: data.hpp:212
SGDParams & crossValidation(const ExampleSet &examples, double propCV, int cvCount, int cvSlices, bool cvShuf=true)
Set up the cross-validation parameters given the full training set, the proportion to be used for CV...
Definition: net.hpp:356
SGDParams(double _eta, int _iters)
Constructor which sets up defaults with no information about examples - cross-validation is not set u...
Definition: net.hpp:323
int getInputCount() const
get the number of inputs
Definition: net.hpp:75
virtual void save(double *buf) const =0
Serialize the data (not including any network type magic number or layer/node counts) to the given me...
A modulatory network architecture which uses a plain backprop network with an extra input to carry th...
Definition: hinet.hpp:17
double trainSGD(ExampleSet &examples, SGDParams &params)
Train using stochastic gradient descent. Note that cross-validation parameters are slightly different...
Definition: net.hpp:428
virtual void setH(double h)=0
Set the modulator level for subsequent runs and training of this network.
int getOutputCount() const
get the number of outputs
Definition: net.hpp:82
double sigmoidDiff(double x)
Definition: net.hpp:28
double * getOutputs(int example)
Get a pointer to the outputs for a given example, for reading or writing.
Definition: data.hpp:349
int cvInterval
how often to cross-validate given as the interval between CV events: 1 is every iteration, 2 is every other iteration and so on.
Definition: net.hpp:203
virtual double getH() const =0
get the modulator level
virtual double * getOutputs() const =0
Get the outputs after running.
double sigmoid(double x)
Definition: net.hpp:20
int nPerSlice
the number of example per cross-validation slice
Definition: net.hpp:197
void shuffle(drand48_data *rd, ShuffleMode mode, int nExamples=0)
Shuffle the example using a PRNG and a Fisher-Yates shuffle.
Definition: data.hpp:259
SGDParams & storeBest()
set up a "best net buffer" to store the best network found, to which the network will be set on compl...
Definition: net.hpp:394
bool storeBestNet
true if we should store the best net data
Definition: net.hpp:289
Contains formats for example data.
virtual void setInputs(double *d)=0
Set the inputs to the network before running or training.
double * bestNetBuffer
a buffer of at least getDataSize() bytes for the best network. If NULL, the best network is not saved...
Definition: net.hpp:284
virtual int getLayerSize(int n) const =0
Get the number of nodes in a given layer.
Net(NetType tp)
Constructor - protected because others inherit it and it's not used directly.
Definition: net.hpp:580
Shuffle blocks of numHLevels examples, rather than single examples. This is intended for cases where ...
Definition: data.hpp:228
SGDParams & setCVShuffle(bool v=true)
fluent setter for cvShuffle
Definition: net.hpp:252
virtual ~Net()
virtual destructor which does nothing
Definition: net.hpp:47
SGDParams & setSeed(long v)
fluent setter for seed
Definition: net.hpp:275
bool selectBestWithCV
if true, use the minimum CV error to find the best net, otherwise use the training error...
Definition: net.hpp:236
double * run(double *in)
Run the network on some data.
Definition: net.hpp:107
int getOutputCount() const
get the number of outputs in all examples
Definition: data.hpp:319
virtual void initWeights(double initr)=0
initialise weights to random values
virtual double trainBatch(ExampleSet &ex, int start, int num, double eta)=0
Train a network for batch (or mini-batch) (or single example).
virtual void update()=0
Run a single update of the network.
long seed
seed for random number generator used to initialise weights and also perform shuffling ...
Definition: net.hpp:272
virtual int getDataSize() const =0
Get the length of the serialised data block for this network.
virtual void load(double *buf)=0
Given that the pointer points to a data block of the correct size for the current network...
int iterations
number of iterations to run: an iteration is the presentation of a single example, NOT an epoch (or occasionally pair-presentation) as is the case in the thesis when discussing the modulatory network types.
Definition: net.hpp:181
int nSlices
The number of cross-validation slices to use.
Definition: net.hpp:192
NetType type
type of the network, used for load/save
Definition: net.hpp:49
drand48_data rd
PRNG data (thread safe)
Definition: net.hpp:50
SGDParams & setShuffle(ExampleSet::ShuffleMode m)
fluent setter for preserveHAlternation
Definition: net.hpp:224
bool cvShuffle
if true, shuffle the entire CV data set when all slices have been done so that the cross-validation h...
Definition: net.hpp:249
SGDParams(double _eta, const ExampleSet &examples, int _iters)
Definition: net.hpp:332
void setSeed(long seed)
Set this network's random number generator, which is used for weight initialisation done at the start...
Definition: net.hpp:57
double getH(int example) const
Get the h (modulator) for a given example.
Definition: data.hpp:359
~SGDParams()
Destructor.
Definition: net.hpp:340
double test(ExampleSet &examples, int start=0, int num=-1)
Test a network. Runs the network over a set of examples and returns the mean MSE for all outputs whe...
Definition: net.hpp:142
The abstract network type upon which all others are based. It's not pure virtual, in that it encapsul...
Definition: net.hpp:39
double * getInputs(int example)
Get a pointer to the inputs for a given example, for reading or writing.
Definition: data.hpp:338
double drand(double mn, double mx)
get a random number using this net&#39;s PRNG data
Definition: net.hpp:591
A set of example data. Each datum consists of hormone (i.e. modulator value), inputs and outputs...
Definition: data.hpp:57
ExampleSet::ShuffleMode shuffleMode
The shuffle mode to use - see the ExampleSet::ShuffleMode enum for details.
Definition: net.hpp:221