Scippy

SCIP

Solving Constraint Integer Programs

bandit_ucb.c
Go to the documentation of this file.
1 /* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */
2 /* */
3 /* This file is part of the program and library */
4 /* SCIP --- Solving Constraint Integer Programs */
5 /* */
6 /* Copyright (C) 2002-2019 Konrad-Zuse-Zentrum */
7 /* fuer Informationstechnik Berlin */
8 /* */
9 /* SCIP is distributed under the terms of the ZIB Academic License. */
10 /* */
11 /* You should have received a copy of the ZIB Academic License */
12 /* along with SCIP; see the file COPYING. If not visit scip.zib.de. */
13 /* */
14 /* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */
15 
16 /**@file bandit_ucb.c
17  * @brief methods for UCB bandit selection
18  * @author Gregor Hendel
19  */
20 
21 /*---+----1----+----2----+----3----+----4----+----5----+----6----+----7----+----8----+----9----+----0----+----1----+----2*/
22 
23 #include "scip/bandit.h"
24 #include "scip/bandit_ucb.h"
25 #include "scip/pub_bandit.h"
26 #include "scip/pub_message.h"
27 #include "scip/pub_misc.h"
28 #include "scip/pub_misc_sort.h"
29 #include "scip/scip_bandit.h"
30 #include "scip/scip_mem.h"
31 #include "scip/scip_randnumgen.h"
32 
33 
#define BANDIT_NAME "ucb"  /**< name of the UCB virtual function table, used to look it up in SCIPcreateBanditUcb() */
#define NUMEPS 1e-6        /**< tolerance for comparing confidence bounds, and maximal magnitude of the random
                            *   perturbation added to priorities when building the start permutation */
36 
37 /*
38  * Data structures
39  */
40 
41 /** implementation specific data of UCB bandit algorithm */
42 struct SCIP_BanditData
43 {
44  int nselections; /**< counter for the number of selections */
45  int* counter; /**< array of counters how often every action has been chosen */
46  int* startperm; /**< indices for starting permutation */
47  SCIP_Real* meanscores; /**< array of average scores for the actions */
48  SCIP_Real alpha; /**< parameter to increase confidence width */
49 };
50 
51 
52 /*
53  * Local methods
54  */
55 
56 /** data reset method */
57 static
59  BMS_BUFMEM* bufmem, /**< buffer memory */
60  SCIP_BANDIT* ucb, /**< ucb bandit algorithm */
61  SCIP_BANDITDATA* banditdata, /**< UCB bandit data structure */
62  SCIP_Real* priorities, /**< priorities for start permutation, or NULL */
63  int nactions /**< number of actions */
64  )
65 {
66  int i;
67  SCIP_RANDNUMGEN* rng;
68 
69  assert(bufmem != NULL);
70  assert(ucb != NULL);
71  assert(nactions > 0);
72 
73  /* clear counters and scores */
74  BMSclearMemoryArray(banditdata->counter, nactions);
75  BMSclearMemoryArray(banditdata->meanscores, nactions);
76  banditdata->nselections = 0;
77 
78  rng = SCIPbanditGetRandnumgen(ucb);
79  assert(rng != NULL);
80 
81  /* initialize start permutation as identity */
82  for( i = 0; i < nactions; ++i )
83  banditdata->startperm[i] = i;
84 
85  /* prepare the start permutation in decreasing order of priority */
86  if( priorities != NULL )
87  {
88  SCIP_Real* prioritycopy;
89 
90  SCIP_ALLOC( BMSduplicateBufferMemoryArray(bufmem, &prioritycopy, priorities, nactions) );
91 
92  /* randomly wiggle priorities a little bit to make them unique */
93  for( i = 0; i < nactions; ++i )
94  prioritycopy[i] += SCIPrandomGetReal(rng, -NUMEPS, NUMEPS);
95 
96  SCIPsortDownRealInt(prioritycopy, banditdata->startperm, nactions);
97 
98  BMSfreeBufferMemoryArray(bufmem, &prioritycopy);
99  }
100  else
101  {
102  /* use a random start permutation */
103  SCIPrandomPermuteIntArray(rng, banditdata->startperm, 0, nactions);
104  }
105 
106  return SCIP_OKAY;
107 }
108 
109 
110 /*
111  * Callback methods of bandit algorithm
112  */
113 
114 /** callback to free bandit specific data structures */
115 SCIP_DECL_BANDITFREE(SCIPbanditFreeUcb)
116 { /*lint --e{715}*/
117  SCIP_BANDITDATA* banditdata;
118  int nactions;
119  assert(bandit != NULL);
120 
121  banditdata = SCIPbanditGetData(bandit);
122  assert(banditdata != NULL);
123  nactions = SCIPbanditGetNActions(bandit);
124 
125  BMSfreeBlockMemoryArray(blkmem, &banditdata->counter, nactions);
126  BMSfreeBlockMemoryArray(blkmem, &banditdata->startperm, nactions);
127  BMSfreeBlockMemoryArray(blkmem, &banditdata->meanscores, nactions);
128  BMSfreeBlockMemory(blkmem, &banditdata);
129 
130  SCIPbanditSetData(bandit, NULL);
131 
132  return SCIP_OKAY;
133 }
134 
135 /** selection callback for bandit selector */
136 SCIP_DECL_BANDITSELECT(SCIPbanditSelectUcb)
137 { /*lint --e{715}*/
138  SCIP_BANDITDATA* banditdata;
139  int nactions;
140  int* counter;
141 
142  assert(bandit != NULL);
143  assert(selection != NULL);
144 
145  banditdata = SCIPbanditGetData(bandit);
146  assert(banditdata != NULL);
147  nactions = SCIPbanditGetNActions(bandit);
148 
149  counter = banditdata->counter;
150  /* select the next uninitialized action from the start permutation */
151  if( banditdata->nselections < nactions )
152  {
153  *selection = banditdata->startperm[banditdata->nselections];
154  assert(counter[*selection] == 0);
155  }
156  else
157  {
158  /* select the action with the highest upper confidence bound */
159  SCIP_Real* meanscores;
160  SCIP_Real widthfactor;
161  SCIP_Real maxucb;
162  int i;
164  meanscores = banditdata->meanscores;
165 
166  assert(rng != NULL);
167  assert(meanscores != NULL);
168 
169  /* compute the confidence width factor that is common for all actions */
170  /* cppcheck-suppress unpreciseMathCall */
171  widthfactor = banditdata->alpha * LOG1P((SCIP_Real)banditdata->nselections);
172  widthfactor = sqrt(widthfactor);
173  maxucb = -1.0;
174 
175  /* loop over the actions and determine the maximum upper confidence bound.
176  * The upper confidence bound of an action is the sum of its mean score
177  * plus a confidence term that decreases with increasing number of observations of
178  * this action.
179  */
180  for( i = 0; i < nactions; ++i )
181  {
182  SCIP_Real uppercb;
183  SCIP_Real rootcount;
184  assert(counter[i] > 0);
185 
186  /* compute the upper confidence bound for action i */
187  uppercb = meanscores[i];
188  rootcount = sqrt((SCIP_Real)counter[i]);
189  uppercb += widthfactor / rootcount;
190  assert(uppercb > 0);
191 
192  /* update maximum, breaking ties uniformly at random */
193  if( EPSGT(uppercb, maxucb, NUMEPS) || (EPSEQ(uppercb, maxucb, NUMEPS) && SCIPrandomGetReal(rng, 0.0, 1.0) >= 0.5) )
194  {
195  maxucb = uppercb;
196  *selection = i;
197  }
198  }
199  }
200 
201  assert(*selection >= 0);
202  assert(*selection < nactions);
203 
204  return SCIP_OKAY;
205 }
206 
207 /** update callback for bandit algorithm */
208 SCIP_DECL_BANDITUPDATE(SCIPbanditUpdateUcb)
209 { /*lint --e{715}*/
210  SCIP_BANDITDATA* banditdata;
211  SCIP_Real delta;
212 
213  assert(bandit != NULL);
214 
215  banditdata = SCIPbanditGetData(bandit);
216  assert(banditdata != NULL);
217  assert(selection >= 0);
218  assert(selection < SCIPbanditGetNActions(bandit));
219 
220  /* increase the mean by the incremental formula: A_n = A_n-1 + 1/n (a_n - A_n-1) */
221  delta = score - banditdata->meanscores[selection];
222  ++banditdata->counter[selection];
223  banditdata->meanscores[selection] += delta / (SCIP_Real)banditdata->counter[selection];
224 
225  banditdata->nselections++;
226 
227  return SCIP_OKAY;
228 }
229 
230 /** reset callback for bandit algorithm */
231 SCIP_DECL_BANDITRESET(SCIPbanditResetUcb)
232 { /*lint --e{715}*/
233  SCIP_BANDITDATA* banditdata;
234  int nactions;
235 
236  assert(bufmem != NULL);
237  assert(bandit != NULL);
238 
239  banditdata = SCIPbanditGetData(bandit);
240  assert(banditdata != NULL);
241  nactions = SCIPbanditGetNActions(bandit);
242 
243  /* call the data reset for the given priorities */
244  SCIP_CALL( dataReset(bufmem, bandit, banditdata, priorities, nactions) );
245 
246  return SCIP_OKAY;
247 }
248 
249 /*
250  * bandit algorithm specific interface methods
251  */
252 
253 /** returns the upper confidence bound of a selected action */
255  SCIP_BANDIT* ucb, /**< UCB bandit algorithm */
256  int action /**< index of the queried action */
257  )
258 {
259  SCIP_Real uppercb;
260  SCIP_BANDITDATA* banditdata;
261  int nactions;
262 
263  assert(ucb != NULL);
264  banditdata = SCIPbanditGetData(ucb);
265  nactions = SCIPbanditGetNActions(ucb);
266  assert(action < nactions);
267 
268  /* since only scores between 0 and 1 are allowed, 1.0 is a sure upper confidence bound */
269  if( banditdata->nselections < nactions )
270  return 1.0;
271 
272  /* the bandit algorithm must have picked every action once */
273  assert(banditdata->counter[action] > 0);
274  uppercb = banditdata->meanscores[action];
275 
276  /* cppcheck-suppress unpreciseMathCall */
277  uppercb += sqrt(banditdata->alpha * LOG1P((SCIP_Real)banditdata->nselections) / (SCIP_Real)banditdata->counter[action]);
278 
279  return uppercb;
280 }
281 
282 /** return start permutation of the UCB bandit algorithm */
284  SCIP_BANDIT* ucb /**< UCB bandit algorithm */
285  )
286 {
287  SCIP_BANDITDATA* banditdata = SCIPbanditGetData(ucb);
288 
289  assert(banditdata != NULL);
290 
291  return banditdata->startperm;
292 }
293 
294 /** internal method to create and reset UCB bandit algorithm */
296  BMS_BLKMEM* blkmem, /**< block memory */
297  BMS_BUFMEM* bufmem, /**< buffer memory */
298  SCIP_BANDITVTABLE* vtable, /**< virtual function table for UCB bandit algorithm */
299  SCIP_BANDIT** ucb, /**< pointer to store bandit algorithm */
300  SCIP_Real* priorities, /**< nonnegative priorities for each action, or NULL if not needed */
301  SCIP_Real alpha, /**< parameter to increase confidence width */
302  int nactions, /**< the positive number of actions for this bandit algorithm */
303  unsigned int initseed /**< initial random seed */
304  )
305 {
306  SCIP_BANDITDATA* banditdata;
307 
308  if( alpha < 0.0 )
309  {
310  SCIPerrorMessage("UCB requires nonnegative alpha parameter, have %f\n", alpha);
311  return SCIP_INVALIDDATA;
312  }
313 
314  SCIP_ALLOC( BMSallocBlockMemory(blkmem, &banditdata) );
315  assert(banditdata != NULL);
316 
317  SCIP_ALLOC( BMSallocBlockMemoryArray(blkmem, &banditdata->counter, nactions) );
318  SCIP_ALLOC( BMSallocBlockMemoryArray(blkmem, &banditdata->startperm, nactions) );
319  SCIP_ALLOC( BMSallocBlockMemoryArray(blkmem, &banditdata->meanscores, nactions) );
320 
321  banditdata->alpha = alpha;
322 
323  SCIP_CALL( SCIPbanditCreate(ucb, vtable, blkmem, bufmem, priorities, nactions, initseed, banditdata) );
324 
325  return SCIP_OKAY;
326 }
327 
328 /** create and reset UCB bandit algorithm */
330  SCIP* scip, /**< SCIP data structure */
331  SCIP_BANDIT** ucb, /**< pointer to store bandit algorithm */
332  SCIP_Real* priorities, /**< nonnegative priorities for each action, or NULL if not needed */
333  SCIP_Real alpha, /**< parameter to increase confidence width */
334  int nactions, /**< the positive number of actions for this bandit algorithm */
335  unsigned int initseed /**< initial random number seed */
336  )
337 {
338  SCIP_BANDITVTABLE* vtable;
339 
340  vtable = SCIPfindBanditvtable(scip, BANDIT_NAME);
341  if( vtable == NULL )
342  {
343  SCIPerrorMessage("Could not find virtual function table for %s bandit algorithm\n", BANDIT_NAME);
344  return SCIP_INVALIDDATA;
345  }
346 
347  SCIP_CALL( SCIPbanditCreateUcb(SCIPblkmem(scip), SCIPbuffer(scip), vtable, ucb,
348  priorities, alpha, nactions, SCIPinitializeRandomSeed(scip, (int)(initseed % INT_MAX))) );
349 
350  return SCIP_OKAY;
351 }
352 
353 /** include virtual function table for UCB bandit algorithms */
355  SCIP* scip /**< SCIP data structure */
356  )
357 {
358  SCIP_BANDITVTABLE* vtable;
359 
361  SCIPbanditFreeUcb, SCIPbanditSelectUcb, SCIPbanditUpdateUcb, SCIPbanditResetUcb) );
362  assert(vtable != NULL);
363 
364  return SCIP_OKAY;
365 }
#define NULL
Definition: def.h:246
SCIP_RETCODE SCIPcreateBanditUcb(SCIP *scip, SCIP_BANDIT **ucb, SCIP_Real *priorities, SCIP_Real alpha, int nactions, unsigned int initseed)
Definition: bandit_ucb.c:329
public methods for memory management
#define EPSEQ(x, y, eps)
Definition: def.h:182
internal methods for bandit algorithms
enum SCIP_Retcode SCIP_RETCODE
Definition: type_retcode.h:53
void SCIPsortDownRealInt(SCIP_Real *realarray, int *intarray, int len)
#define BMSduplicateBufferMemoryArray(mem, ptr, source, num)
Definition: memory.h:715
SCIP_BANDITDATA * SCIPbanditGetData(SCIP_BANDIT *bandit)
Definition: bandit.c:180
BMS_BUFMEM * SCIPbuffer(SCIP *scip)
Definition: scip_mem.c:143
#define SCIPerrorMessage
Definition: pub_message.h:45
internal methods for UCB bandit algorithm
SCIPInterval sqrt(const SCIPInterval &x)
SCIP_DECL_BANDITUPDATE(SCIPbanditUpdateUcb)
Definition: bandit_ucb.c:208
BMS_BLKMEM * SCIPblkmem(SCIP *scip)
Definition: scip_mem.c:128
SCIP_BANDITVTABLE * SCIPfindBanditvtable(SCIP *scip, const char *name)
Definition: scip_bandit.c:64
SCIP_DECL_BANDITRESET(SCIPbanditResetUcb)
Definition: bandit_ucb.c:231
#define SCIP_CALL(x)
Definition: def.h:358
void SCIPbanditSetData(SCIP_BANDIT *bandit, SCIP_BANDITDATA *banditdata)
Definition: bandit.c:190
#define NUMEPS
Definition: bandit_ucb.c:35
#define BMSfreeBlockMemory(mem, ptr)
Definition: memory.h:453
public data structures and miscellaneous methods
SCIP_RETCODE SCIPincludeBanditvtableUcb(SCIP *scip)
Definition: bandit_ucb.c:354
#define BMSallocBlockMemoryArray(mem, ptr, num)
Definition: memory.h:442
SCIP_RETCODE SCIPincludeBanditvtable(SCIP *scip, SCIP_BANDITVTABLE **banditvtable, const char *name, SCIP_DECL_BANDITFREE((*banditfree)), SCIP_DECL_BANDITSELECT((*banditselect)), SCIP_DECL_BANDITUPDATE((*banditupdate)), SCIP_DECL_BANDITRESET((*banditreset)))
Definition: scip_bandit.c:32
SCIP_DECL_BANDITSELECT(SCIPbanditSelectUcb)
Definition: bandit_ucb.c:136
#define BMSfreeBlockMemoryArray(mem, ptr, num)
Definition: memory.h:455
void SCIPrandomPermuteIntArray(SCIP_RANDNUMGEN *randnumgen, int *array, int begin, int end)
Definition: misc.c:9649
int * SCIPgetStartPermutationUcb(SCIP_BANDIT *ucb)
Definition: bandit_ucb.c:283
public methods for bandit algorithms
#define BANDIT_NAME
Definition: bandit_ucb.c:34
SCIP_Real SCIPrandomGetReal(SCIP_RANDNUMGEN *randnumgen, SCIP_Real minrandval, SCIP_Real maxrandval)
Definition: misc.c:9630
public methods for bandit algorithms
struct SCIP_BanditData SCIP_BANDITDATA
Definition: type_bandit.h:47
methods for sorting joint arrays of various types
#define EPSGT(x, y, eps)
Definition: def.h:185
public methods for random numbers
public methods for message output
SCIP_RETCODE SCIPbanditCreateUcb(BMS_BLKMEM *blkmem, BMS_BUFMEM *bufmem, SCIP_BANDITVTABLE *vtable, SCIP_BANDIT **ucb, SCIP_Real *priorities, SCIP_Real alpha, int nactions, unsigned int initseed)
Definition: bandit_ucb.c:295
#define SCIP_Real
Definition: def.h:157
int SCIPbanditGetNActions(SCIP_BANDIT *bandit)
Definition: bandit.c:293
#define LOG1P(x)
Definition: def.h:229
static SCIP_RETCODE dataReset(BMS_BUFMEM *bufmem, SCIP_BANDIT *ucb, SCIP_BANDITDATA *banditdata, SCIP_Real *priorities, int nactions)
Definition: bandit_ucb.c:58
SCIP_RANDNUMGEN * SCIPbanditGetRandnumgen(SCIP_BANDIT *bandit)
Definition: bandit.c:283
#define BMSallocBlockMemory(mem, ptr)
Definition: memory.h:440
#define BMSclearMemoryArray(ptr, num)
Definition: memory.h:119
unsigned int SCIPinitializeRandomSeed(SCIP *scip, unsigned int initialseedvalue)
struct BMS_BlkMem BMS_BLKMEM
Definition: memory.h:426
#define SCIP_ALLOC(x)
Definition: def.h:369
SCIP_Real SCIPgetConfidenceBoundUcb(SCIP_BANDIT *ucb, int action)
Definition: bandit_ucb.c:254
#define BMSfreeBufferMemoryArray(mem, ptr)
Definition: memory.h:720
SCIP_RETCODE SCIPbanditCreate(SCIP_BANDIT **bandit, SCIP_BANDITVTABLE *banditvtable, BMS_BLKMEM *blkmem, BMS_BUFMEM *bufmem, SCIP_Real *priorities, int nactions, unsigned int initseed, SCIP_BANDITDATA *banditdata)
Definition: bandit.c:32
SCIP_DECL_BANDITFREE(SCIPbanditFreeUcb)
Definition: bandit_ucb.c:115