UESMANN CPP  1.0
Reference implementation of UESMANN
/home/travis/build/jimfinnis/uesmanncpp/testSaveLoad.cpp
Go to the documentation of this file.
1 
9 #include <iostream>
10 
11 #include <boost/test/unit_test.hpp>
12 
13 #include "test.hpp"
14 
21 BOOST_AUTO_TEST_SUITE(saveload)
22 
23 
25  // generate a new network
26  int layers[3];
27  layers[0]=4;
28  layers[1]=3;
29  layers[2]=2;
30  Net *n = NetFactory::makeNet(tp,3,layers);
31 
32  // generate a toy example. Doesn't matter what it is.
33  ExampleSet e(1,4,2,1);
34  double *p = e.getInputs(0);
35  *p++=0;
36  *p++=2;
37  *p++=3;
38  *p=1;
39  p = e.getOutputs(0);
40  *p++=100;
41  *p=20;
42  e.setH(0,0);
43 
44  // train it a little.
45  Net::SGDParams parms(10,e,100);
46  n->trainSGD(e,parms);
47 
48  // save the net to memory
49  double *oldData = new double[n->getDataSize()];
50  n->save(oldData);
51 
52  // now save the net to disk
53  NetFactory::save("foo.net",n);
54 
55  // and load
56  Net *saved = NetFactory::load("foo.net");
57 
58  BOOST_REQUIRE(n->type == saved->type);
59  BOOST_REQUIRE(n->getDataSize() == saved->getDataSize());
60 
61  // save the newly loaded net params to memory
62  double *savedData = new double[saved->getDataSize()];
63  saved->save(savedData);
64 
65  // and compare params
66  for(int i=0;i<n->getDataSize();i++){
67  BOOST_REQUIRE(oldData[i]==savedData[i]);
68  }
69 
70  delete [] savedData;
71  delete [] oldData;
72  delete n;
73 }
74 
75 
80 BOOST_AUTO_TEST_CASE(saveloadplain) {
82 }
87 BOOST_AUTO_TEST_CASE(saveloadob) {
89 }
94 BOOST_AUTO_TEST_CASE(saveloadhin) {
96 }
101 BOOST_AUTO_TEST_CASE(saveloadues) {
103 }
104 
105 
111 BOOST_AUTO_TEST_SUITE_END()
112 
NetType
The different types of network - each has an associated integer for saving/loading file data...
Definition: netType.hpp:15
static Net * load(const char *fn)
Load a network of any type from a file - note, endianness not checked!
Definition: netFactory.hpp:61
output blending
void setH(int example, double h)
Set the h (modulator) for a given example.
Definition: data.hpp:378
Training parameters for trainSGD(). This structure holds the parameters for the trainSGD() method...
Definition: net.hpp:173
Useful stuff for testing.
virtual void save(double *buf) const =0
Serialize the data (not including any network type magic number or layer/node counts) to the given me...
double trainSGD(ExampleSet &examples, SGDParams &params)
Train using stochastic gradient descent. Note that cross-validation parameters are slightly different...
Definition: net.hpp:428
plain back-propagation
double * getOutputs(int example)
Get a pointer to the outputs for a given example, for reading or writing.
Definition: data.hpp:349
static void save(const char *fn, Net *n)
Save a net of any type to a file - note, endianness not checked!
Definition: netFactory.hpp:119
static Net * makeNet(NetType t, ExampleSet &e, int hnodes)
Construct a single hidden layer network of a given type which conforms to the example set...
Definition: netFactory.hpp:32
virtual int getDataSize() const =0
Get the length of the serialised data block for this network.
void testSaveLoad(NetType tp)
NetType type
type of the network, used for load/save
Definition: net.hpp:49
BOOST_AUTO_TEST_CASE(saveloadplain)
Test that saving and loading a plain network leaves the weights and biases unchanged.
h-as-input
The abstract network type upon which all others are based. It's not pure virtual, in that it encapsul...
Definition: net.hpp:39
double * getInputs(int example)
Get a pointer to the inputs for a given example, for reading or writing.
Definition: data.hpp:338
A set of example data. Each datum consists of hormone (i.e. modulator value), inputs and outputs...
Definition: data.hpp:57