UESMANN CPP  1.0
Reference implementation of UESMANN
/home/travis/build/jimfinnis/uesmanncpp/mnist.hpp
Go to the documentation of this file.
1 
7 #ifndef __MNIST_HPP
8 #define __MNIST_HPP
9 
10 
11 #include <stdint.h>
12 #include <stdio.h>
13 #include <stdlib.h>
14 #include <errno.h>
15 #include <string.h>
16 #include <arpa/inet.h>
17 #include <stdexcept>
18 
/**
 * \brief Loads a set of labelled images from a pair of MNIST-format files
 * (one label file, one image file) and gives read access to them.
 *
 * Both files start with a header of big-endian 32-bit words. The label file
 * holds a magic number (2049) and an example count followed by one byte per
 * label; the image file holds a magic number (2051), an example count, the
 * row count and the column count, followed by one byte per pixel with each
 * image stored row by row.
 *
 * The object owns the loaded arrays. Copying performs a deep copy.
 */
class MNIST {
public:
    /**
     * \brief Load all or part of the examples in a label/image file pair.
     *
     * Every failure prints a diagnostic to standard output and then throws;
     * nothing is leaked when that happens.
     *
     * \param labelFile path of the label file
     * \param imgFile   path of the image file
     * \param start     index of the first example to load (must be >= 0)
     * \param len       number of examples to load; 0 means "the number of
     *                  examples in the file", which is therefore only a valid
     *                  range when start is also 0
     * \throws std::runtime_error if a file cannot be opened, has a bad or
     *         truncated header, an example count above 100000, dimensions
     *         above 128x128, too little data, or if [start,start+len) does
     *         not fit inside the file.
     */
    MNIST(const char *labelFile,const char *imgFile,int start=0,int len=0)
        : valid(false), rows(0), cols(0), ct(0), maxLabel(0),
          labels(NULL), imgs(NULL)
    {
        try {
            const uint32_t fileCount = loadLabels(labelFile,start,len);
            loadImages(imgFile,labelFile,fileCount,start,len);
        } catch(...) {
            // The destructor does not run for an object whose constructor
            // throws, so the arrays have to be released here.
            delete [] labels;
            delete [] imgs;
            throw;
        }

        // From here on the count is the number of examples actually loaded.
        ct = static_cast<uint32_t>(len);

        // get the max label
        for(uint32_t i=0;i<ct;i++){
            if(labels[i]>maxLabel)
                maxLabel = labels[i];
        }
        valid = true;
    }

    /**
     * \brief Copy constructor: takes a deep copy of the label and image data.
     */
    MNIST(const MNIST &other)
        : valid(other.valid), rows(other.rows), cols(other.cols),
          ct(other.ct), maxLabel(other.maxLabel), labels(NULL), imgs(NULL)
    {
        labels = cloneBytes(other.labels,ct);
        try {
            imgs = cloneBytes(other.imgs,pixelBytes());
        } catch(...) {
            delete [] labels;
            throw;
        }
    }

    /**
     * \brief Copy assignment: replaces this object's data with a deep copy.
     * The copy is built first, so this object is untouched if copying throws.
     */
    MNIST &operator=(const MNIST &other){
        if(this!=&other){
            MNIST tmp(other);
            swapWith(tmp);
        }
        return *this;
    }

    /**
     * \brief Destructor: releases the label and image arrays.
     */
    ~MNIST(){
        delete [] labels;
        delete [] imgs;
    }

    /**
     * \brief returns the number of examples loaded
     */
    int getCount() const {
        return ct;
    }

    /**
     * \brief returns the number of rows in each image
     */
    int r() const {
        return rows;
    }

    /**
     * \brief returns the number of columns in each image
     */
    int c() const {
        return cols;
    }

    /**
     * \brief get the label for a given example
     * \param n example index; not range checked, must be in [0,getCount())
     */
    uint8_t getLabel(int n) const {
        return labels[n];
    }

    /**
     * \brief get the largest label value found in the loaded examples
     */
    uint8_t getMaxLabel() const {
        return maxLabel;
    }

    /**
     * \brief get a pointer to the first pixel of a given example; the image
     * occupies the following r()*c() bytes, row by row.
     * \param n example index; not range checked, must be in [0,getCount())
     */
    uint8_t *getImg(int n) const {
        return imgs + static_cast<size_t>(rows)*cols*static_cast<size_t>(n);
    }

    /**
     * \brief get a pixel for a given example
     * \param n example index; not range checked
     * \param x column of the pixel; not range checked
     * \param y row of the pixel; not range checked
     */
    uint8_t getPix(int n,int x,int y) const {
        const int idx = x + y*static_cast<int>(cols);
        return getImg(n)[idx];
    }

    /**
     * \brief dump an example to standard out: its label, then the image as
     * one text line per row, each pixel shown as '.' or a digit 1-9 derived
     * from pixel/25 (capped at 9).
     * \param i example index; prints "Out of range" if it is not valid
     */
    void dump(int i) const {
        if(i<0 || i>=getCount()){
            printf("Out of range\n");
            return;
        }
        printf("Label: %d\n",getLabel(i));
        const uint8_t *px = getImg(i);
        for(int row=0;row<r();row++){
            for(int col=0;col<c();col++){
                uint8_t level = *px++ / 25;
                if(level>9)level=9; // 250..255 would otherwise give 10
                putchar(level ? level+'0' : '.');
            }
            putchar('\n');
        }
    }

private:
    /**
     * \brief Closes a stdio stream when it goes out of scope, so that the
     * error paths (which throw) cannot leak the handle.
     */
    struct File {
        FILE *fp;
        explicit File(FILE *f) : fp(f) {}
        ~File(){ if(fp)fclose(fp); }
    private:
        File(const File &);            // not copyable
        File &operator=(const File &); // not copyable
    };

    /**
     * \brief Read one big-endian 32-bit header word, failing loudly if the
     * file ends first.
     * \param kind "label" or "image", used only in the error text
     */
    static uint32_t readHeaderWord(FILE *f,const char *kind,const char *path){
        unsigned char b[4];
        if(fread(b,1,sizeof(b),f)!=sizeof(b)){
            printf("truncated header in %s file %s\n",kind,path);
            throw std::runtime_error(std::string("truncated header in ")+kind+" file");
        }
        // Assemble from bytes rather than byte-swapping a host integer, so
        // the result does not depend on the host's endianness.
        return (static_cast<uint32_t>(b[0])<<24) |
               (static_cast<uint32_t>(b[1])<<16) |
               (static_cast<uint32_t>(b[2])<<8)  |
                static_cast<uint32_t>(b[3]);
    }

    /**
     * \brief Read the label file into `labels`.
     * \param len in: requested length (0 = file count); out: resolved length
     * \return the example count stated in the label file header
     */
    uint32_t loadLabels(const char *labelFile,int start,int &len){
        File in(fopen(labelFile,"rb"));
        if(!in.fp){
            printf("Error opening label file %s: %s\n",labelFile,strerror(errno));
            throw std::runtime_error(std::string("cannot open label file: ")+labelFile);
        }

        const uint32_t magic = readHeaderWord(in.fp,"label",labelFile);
        if(magic!=2049){
            printf("incorrect magic number in label file %s: %x\n",labelFile,magic);
            throw std::runtime_error("bad magic number in label file");
        }
        const uint32_t total = readHeaderWord(in.fp,"label",labelFile);
        if(total>100000){
            printf("unfeasibly large count in label file %s: %x\n",labelFile,total);
            throw std::runtime_error("bad count in label file");
        }

        if(!len)len=static_cast<int>(total);
        // Reject negatives before any unsigned comparison, and compare
        // without adding start+len so the check itself cannot overflow.
        if(start<0 || len<0 ||
           static_cast<uint32_t>(start)>total ||
           static_cast<uint32_t>(len)>total-static_cast<uint32_t>(start)){
            printf("specified range [%d-%lld], only %u in file %s\n",
                   start,static_cast<long long>(start)+len,total,labelFile);
            throw std::runtime_error("bad range in label file");
        }

        // skip the labels before the requested range (one byte each)
        if(fseek(in.fp,static_cast<long>(start),SEEK_CUR)!=0){
            printf("cannot seek in label file %s: %s\n",labelFile,strerror(errno));
            throw std::runtime_error("cannot seek in label file");
        }

        labels = new uint8_t[len];
        const size_t got = fread(labels,sizeof(uint8_t),len,in.fp);
        if(got!=static_cast<size_t>(len)){
            printf("not enough items in label file %s: %d\n",labelFile,static_cast<int>(got));
            throw std::runtime_error("not enough elements in label file");
        }
        return total;
    }

    /**
     * \brief Read the image file into `imgs`, setting `rows` and `cols`.
     * \param labelFile  label file path, used only in an error message
     * \param labelCount example count from the label file header, which the
     *                   image file header must match
     * \param start,len  already validated range of examples to load
     */
    void loadImages(const char *imgFile,const char *labelFile,
                    uint32_t labelCount,int start,int len){
        File in(fopen(imgFile,"rb"));
        if(!in.fp){
            printf("Error opening image file %s: %s\n",imgFile,strerror(errno));
            throw std::runtime_error(std::string("cannot open image file: ")+imgFile);
        }

        const uint32_t magic = readHeaderWord(in.fp,"image",imgFile);
        if(magic!=2051){
            printf("incorrect magic number in image file %s: %d\n",imgFile,magic);
            throw std::runtime_error("bad magic in image file");
        }
        const uint32_t imgCount = readHeaderWord(in.fp,"image",imgFile);
        if(imgCount!=labelCount){
            printf("image file count does not agree with label file count:\n"
                   "%s:%d != %s:%d\n",
                   imgFile,imgCount,labelFile,labelCount);
            throw std::runtime_error("bad count in image file");
        }

        rows = readHeaderWord(in.fp,"image",imgFile);
        cols = readHeaderWord(in.fp,"image",imgFile);
        if(rows > 128 || cols > 128){
            printf("Bad dimensions in image file %s: %dx%d\n",imgFile,rows,cols);
            throw std::runtime_error("bad dimensions in image file");
        }

        // skip the images before the requested range
        const size_t perImage = static_cast<size_t>(rows)*cols;
        if(fseek(in.fp,static_cast<long>(perImage*static_cast<size_t>(start)),SEEK_CUR)!=0){
            printf("cannot seek in image file %s: %s\n",imgFile,strerror(errno));
            throw std::runtime_error("cannot seek in image file");
        }

        const size_t wanted = perImage*static_cast<size_t>(len);
        imgs = new uint8_t[wanted];
        const size_t got = fread(imgs,sizeof(uint8_t),wanted,in.fp);
        if(got!=wanted){
            printf("wrong amount of pixels in image file %s: %d\n",imgFile,static_cast<int>(got));
            throw std::runtime_error("bad filesize in image file");
        }
    }

    /// total number of pixel bytes held in `imgs`
    size_t pixelBytes() const {
        return static_cast<size_t>(rows)*cols*ct;
    }

    /// allocate a new array holding a copy of n bytes from src
    static uint8_t *cloneBytes(const uint8_t *src,size_t n){
        uint8_t *dst = new uint8_t[n];
        if(n)memcpy(dst,src,n);
        return dst;
    }

    /// exchange two values of the same type
    template<class T> static void exchange(T &a,T &b){
        T t = a;
        a = b;
        b = t;
    }

    /// exchange the entire state of this object with another
    void swapWith(MNIST &o){
        exchange(valid,o.valid);
        exchange(rows,o.rows);
        exchange(cols,o.cols);
        exchange(ct,o.ct);
        exchange(maxLabel,o.maxLabel);
        exchange(labels,o.labels);
        exchange(imgs,o.imgs);
    }

    /// set once loading has completed successfully; not read by this class
    bool valid;

    /// number of rows in each image
    uint32_t rows;

    /// number of columns in each image
    uint32_t cols;

    /// number of examples loaded
    uint32_t ct;

    /// largest label value among the loaded examples
    uint8_t maxLabel;

    /// owned array of ct labels, one byte each
    uint8_t *labels;

    /// owned array of ct*rows*cols pixel bytes, image after image
    uint8_t *imgs;
};
246 
247 
248 #endif /* __MNIST_HPP */
uint8_t getMaxLabel() const
get the maximum label value (0 to 9 in the original data but different in other tests) ...
Definition: mnist.hpp:166
int r() const
returns the number of rows in each image
Definition: mnist.hpp:142
void dump(int i) const
dump the image data to standard out
Definition: mnist.hpp:192
~MNIST()
Destructor.
Definition: mnist.hpp:127
This class encapsulates and loads data in the standard MNIST format. The data resides in two files...
Definition: mnist.hpp:24
uint8_t getLabel(int n) const
get the label for a given example
Definition: mnist.hpp:158
uint8_t getPix(int n, int x, int y) const
get a pixel for a given example
Definition: mnist.hpp:184
int c() const
returns the number of columns in each image
Definition: mnist.hpp:150
MNIST(const char *labelFile, const char *imgFile, int start=0, int len=0)
constructor which loads the data from the given file, and can load only part of the data in a file...
Definition: mnist.hpp:35
uint8_t * getImg(int n) const
get the bitmap for a given example
Definition: mnist.hpp:175
int getCount() const
returns the number of examples
Definition: mnist.hpp:135