Scippy

SCIP

Solving Constraint Integer Programs

bandit_exp3.c
Go to the documentation of this file.
1 /* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */
2 /* */
3 /* This file is part of the program and library */
4 /* SCIP --- Solving Constraint Integer Programs */
5 /* */
6 /* Copyright (c) 2002-2024 Zuse Institute Berlin (ZIB) */
7 /* */
8 /* Licensed under the Apache License, Version 2.0 (the "License"); */
9 /* you may not use this file except in compliance with the License. */
10 /* You may obtain a copy of the License at */
11 /* */
12 /* http://www.apache.org/licenses/LICENSE-2.0 */
13 /* */
14 /* Unless required by applicable law or agreed to in writing, software */
15 /* distributed under the License is distributed on an "AS IS" BASIS, */
16 /* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */
17 /* See the License for the specific language governing permissions and */
18 /* limitations under the License. */
19 /* */
20 /* You should have received a copy of the Apache-2.0 license */
21 /* along with SCIP; see the file LICENSE. If not visit scipopt.org. */
22 /* */
23 /* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */
24 
25 /**@file bandit_exp3.c
26  * @ingroup OTHER_CFILES
27  * @brief methods for Exp.3 bandit selection
28  * @author Gregor Hendel
29  */
30 
31 /*---+----1----+----2----+----3----+----4----+----5----+----6----+----7----+----8----+----9----+----0----+----1----+----2*/
32 
33 #include "scip/bandit.h"
34 #include "scip/bandit_exp3.h"
35 #include "scip/pub_bandit.h"
36 #include "scip/pub_message.h"
37 #include "scip/pub_misc.h"
38 #include "scip/scip_bandit.h"
39 #include "scip/scip_mem.h"
40 #include "scip/scip_randnumgen.h"
41 
42 #define BANDIT_NAME "exp3"
43 #define NUMTOL 1e-6
44 
45 /*
46  * Data structures
47  */
48 
49 /** implementation specific data of Exp.3 bandit algorithm */
50 struct SCIP_BanditData
51 {
52  SCIP_Real* weights; /**< exponential weight for each arm */
53  SCIP_Real weightsum; /**< the sum of all weights */
54  SCIP_Real gamma; /**< weight between uniform (gamma ~ 1) and weight driven (gamma ~ 0) probability distribution */
55  SCIP_Real beta; /**< gain offset between 0 and 1 at every observation */
56 };
57 
58 /*
59  * Local methods
60  */
61 
62 /*
63  * Callback methods of bandit algorithm
64  */
65 
66 /** callback to free bandit specific data structures */
67 SCIP_DECL_BANDITFREE(SCIPbanditFreeExp3)
68 { /*lint --e{715}*/
69  SCIP_BANDITDATA* banditdata;
70  int nactions;
71  assert(bandit != NULL);
72 
73  banditdata = SCIPbanditGetData(bandit);
74  assert(banditdata != NULL);
75  nactions = SCIPbanditGetNActions(bandit);
76 
77  BMSfreeBlockMemoryArray(blkmem, &banditdata->weights, nactions);
78 
79  BMSfreeBlockMemory(blkmem, &banditdata);
80 
81  SCIPbanditSetData(bandit, NULL);
82 
83  return SCIP_OKAY;
84 }
85 
86 /** selection callback for bandit selector */
87 SCIP_DECL_BANDITSELECT(SCIPbanditSelectExp3)
88 { /*lint --e{715}*/
89  SCIP_BANDITDATA* banditdata;
90  SCIP_RANDNUMGEN* rng;
91  SCIP_Real randnr;
92  SCIP_Real psum;
93  SCIP_Real gammaoverk;
94  SCIP_Real oneminusgamma;
95  SCIP_Real* weights;
96  SCIP_Real weightsum;
97  int i;
98  int nactions;
99 
100  assert(bandit != NULL);
101  assert(selection != NULL);
102 
103  banditdata = SCIPbanditGetData(bandit);
104  assert(banditdata != NULL);
105  rng = SCIPbanditGetRandnumgen(bandit);
106  assert(rng != NULL);
107  nactions = SCIPbanditGetNActions(bandit);
108 
109  /* draw a random number between 0 and 1 */
110  randnr = SCIPrandomGetReal(rng, 0.0, 1.0);
111 
112  /* initialize some local variables to speed up probability computations */
113  oneminusgamma = 1 - banditdata->gamma;
114  gammaoverk = banditdata->gamma / (SCIP_Real)nactions;
115  weightsum = banditdata->weightsum;
116  weights = banditdata->weights;
117  psum = 0.0;
118 
119  /* loop over probability distribution until rand is reached
120  * the loop terminates without looking at the last action,
121  * which is then selected automatically if the target probability
122  * is not reached earlier
123  */
124  for( i = 0; i < nactions - 1; ++i )
125  {
126  SCIP_Real prob;
127 
128  /* compute the probability for arm i as convex kombination of a uniform distribution and a weighted distribution */
129  prob = oneminusgamma * weights[i] / weightsum + gammaoverk;
130  psum += prob;
131 
132  /* break and select element if target probability is reached */
133  if( randnr <= psum )
134  break;
135  }
136 
137  /* select element i, which is the last action in case that the break statement hasn't been reached */
138  *selection = i;
139 
140  return SCIP_OKAY;
141 }
142 
143 /** update callback for bandit algorithm */
144 SCIP_DECL_BANDITUPDATE(SCIPbanditUpdateExp3)
145 { /*lint --e{715}*/
146  SCIP_BANDITDATA* banditdata;
147  SCIP_Real eta;
148  SCIP_Real gainestim;
149  SCIP_Real beta;
150  SCIP_Real weightsum;
151  SCIP_Real newweightsum;
152  SCIP_Real* weights;
153  SCIP_Real oneminusgamma;
154  SCIP_Real gammaoverk;
155  int nactions;
156 
157  assert(bandit != NULL);
158 
159  banditdata = SCIPbanditGetData(bandit);
160  assert(banditdata != NULL);
161  nactions = SCIPbanditGetNActions(bandit);
162 
163  assert(selection >= 0);
164  assert(selection < nactions);
165 
166  /* the learning rate eta */
167  eta = 1.0 / (SCIP_Real)nactions;
168 
169  beta = banditdata->beta;
170  oneminusgamma = 1.0 - banditdata->gamma;
171  gammaoverk = banditdata->gamma * eta;
172  weights = banditdata->weights;
173  weightsum = banditdata->weightsum;
174  newweightsum = weightsum;
175 
176  /* if beta is zero, only the observation for the current arm needs an update */
177  if( EPSZ(beta, NUMTOL) )
178  {
179  SCIP_Real probai;
180  probai = oneminusgamma * weights[selection] / weightsum + gammaoverk;
181 
182  assert(probai > 0.0);
183 
184  gainestim = score / probai;
185  newweightsum -= weights[selection];
186  weights[selection] *= exp(eta * gainestim);
187  newweightsum += weights[selection];
188  }
189  else
190  {
191  int j;
192  newweightsum = 0.0;
193 
194  /* loop over all items and update their weights based on the influence of the beta parameter */
195  for( j = 0; j < nactions; ++j )
196  {
197  SCIP_Real probaj;
198  probaj = oneminusgamma * weights[j] / weightsum + gammaoverk;
199 
200  assert(probaj > 0.0);
201 
202  /* consider the score only for the chosen arm i, use constant beta offset otherwise */
203  if( j == selection )
204  gainestim = (score + beta) / probaj;
205  else
206  gainestim = beta / probaj;
207 
208  weights[j] *= exp(eta * gainestim);
209  newweightsum += weights[j];
210  }
211  }
212 
213  banditdata->weightsum = newweightsum;
214 
215  return SCIP_OKAY;
216 }
217 
218 /** reset callback for bandit algorithm */
219 SCIP_DECL_BANDITRESET(SCIPbanditResetExp3)
220 { /*lint --e{715}*/
221  SCIP_BANDITDATA* banditdata;
222  SCIP_Real* weights;
223  int nactions;
224  int i;
225 
226  assert(bandit != NULL);
227 
228  banditdata = SCIPbanditGetData(bandit);
229  assert(banditdata != NULL);
230  nactions = SCIPbanditGetNActions(bandit);
231  weights = banditdata->weights;
232 
233  assert(nactions > 0);
234 
235  banditdata->weightsum = (1.0 + NUMTOL) * (SCIP_Real)nactions;
236 
237  /* in case of priorities, weights are normalized to sum up to nactions */
238  if( priorities != NULL )
239  {
240  SCIP_Real normalization;
241  SCIP_Real priosum;
242  priosum = 0.0;
243 
244  /* compute sum of priorities */
245  for( i = 0; i < nactions; ++i )
246  {
247  assert(priorities[i] >= 0);
248  priosum += priorities[i];
249  }
250 
251  /* if there are positive priorities, normalize the weights */
252  if( priosum > 0.0 )
253  {
254  normalization = nactions / priosum;
255  for( i = 0; i < nactions; ++i )
256  weights[i] = (priorities[i] * normalization) + NUMTOL;
257  }
258  else
259  {
260  /* use uniform distribution in case of all priorities being 0.0 */
261  for( i = 0; i < nactions; ++i )
262  weights[i] = 1.0 + NUMTOL;
263  }
264  }
265  else
266  {
267  /* use uniform distribution in case of unspecified priorities */
268  for( i = 0; i < nactions; ++i )
269  weights[i] = 1.0 + NUMTOL;
270  }
271 
272  return SCIP_OKAY;
273 }
274 
275 
276 /*
277  * bandit algorithm specific interface methods
278  */
279 
280 /** direct bandit creation method for the core where no SCIP pointer is available */
282  BMS_BLKMEM* blkmem, /**< block memory data structure */
283  BMS_BUFMEM* bufmem, /**< buffer memory */
284  SCIP_BANDITVTABLE* vtable, /**< virtual function table for callback functions of Exp.3 */
285  SCIP_BANDIT** exp3, /**< pointer to store bandit algorithm */
286  SCIP_Real* priorities, /**< nonnegative priorities for each action, or NULL if not needed */
287  SCIP_Real gammaparam, /**< weight between uniform (gamma ~ 1) and weight driven (gamma ~ 0) probability distribution */
288  SCIP_Real beta, /**< gain offset between 0 and 1 at every observation */
289  int nactions, /**< the positive number of actions for this bandit algorithm */
290  unsigned int initseed /**< initial random seed */
291  )
292 {
293  SCIP_BANDITDATA* banditdata;
294 
295  SCIP_ALLOC( BMSallocBlockMemory(blkmem, &banditdata) );
296  assert(banditdata != NULL);
297 
298  banditdata->gamma = gammaparam;
299  banditdata->beta = beta;
300  assert(gammaparam >= 0 && gammaparam <= 1);
301  assert(beta >= 0 && beta <= 1);
302 
303  SCIP_ALLOC( BMSallocBlockMemoryArray(blkmem, &banditdata->weights, nactions) );
304 
305  SCIP_CALL( SCIPbanditCreate(exp3, vtable, blkmem, bufmem, priorities, nactions, initseed, banditdata) );
306 
307  return SCIP_OKAY;
308 }
309 
310 /** creates and resets an Exp.3 bandit algorithm using \p scip pointer */
312  SCIP* scip, /**< SCIP data structure */
313  SCIP_BANDIT** exp3, /**< pointer to store bandit algorithm */
314  SCIP_Real* priorities, /**< nonnegative priorities for each action, or NULL if not needed */
315  SCIP_Real gammaparam, /**< weight between uniform (gamma ~ 1) and weight driven (gamma ~ 0) probability distribution */
316  SCIP_Real beta, /**< gain offset between 0 and 1 at every observation */
317  int nactions, /**< the positive number of actions for this bandit algorithm */
318  unsigned int initseed /**< initial seed for random number generation */
319  )
320 {
321  SCIP_BANDITVTABLE* vtable;
322 
323  vtable = SCIPfindBanditvtable(scip, BANDIT_NAME);
324  if( vtable == NULL )
325  {
326  SCIPerrorMessage("Could not find virtual function table for %s bandit algorithm\n", BANDIT_NAME);
327  return SCIP_INVALIDDATA;
328  }
329 
330  SCIP_CALL( SCIPbanditCreateExp3(SCIPblkmem(scip), SCIPbuffer(scip), vtable, exp3,
331  priorities, gammaparam, beta, nactions, SCIPinitializeRandomSeed(scip, initseed)) );
332 
333  return SCIP_OKAY;
334 }
335 
336 /** set gamma parameter of Exp.3 bandit algorithm to increase weight of uniform distribution */
338  SCIP_BANDIT* exp3, /**< bandit algorithm */
339  SCIP_Real gammaparam /**< weight between uniform (gamma ~ 1) and weight driven (gamma ~ 0) probability distribution */
340  )
341 {
342  SCIP_BANDITDATA* banditdata = SCIPbanditGetData(exp3);
343 
344  assert(gammaparam >= 0 && gammaparam <= 1);
345 
346  banditdata->gamma = gammaparam;
347 }
348 
349 /** set beta parameter of Exp.3 bandit algorithm to increase gain offset for actions that were not played */
351  SCIP_BANDIT* exp3, /**< bandit algorithm */
352  SCIP_Real beta /**< gain offset between 0 and 1 at every observation */
353  )
354 {
355  SCIP_BANDITDATA* banditdata = SCIPbanditGetData(exp3);
356 
357  assert(beta >= 0 && beta <= 1);
358 
359  banditdata->beta = beta;
360 }
361 
362 /** returns probability to play an action */
364  SCIP_BANDIT* exp3, /**< bandit algorithm */
365  int action /**< index of the requested action */
366  )
367 {
368  SCIP_BANDITDATA* banditdata = SCIPbanditGetData(exp3);
369 
370  assert(banditdata->weightsum > 0.0);
371  assert(SCIPbanditGetNActions(exp3) > 0);
372 
373  return (1.0 - banditdata->gamma) * banditdata->weights[action] / banditdata->weightsum + banditdata->gamma / (SCIP_Real)SCIPbanditGetNActions(exp3);
374 }
375 
376 /** include virtual function table for Exp.3 bandit algorithms */
378  SCIP* scip /**< SCIP data structure */
379  )
380 {
381  SCIP_BANDITVTABLE* vtable;
382 
384  SCIPbanditFreeExp3, SCIPbanditSelectExp3, SCIPbanditUpdateExp3, SCIPbanditResetExp3) );
385  assert(vtable != NULL);
386 
387  return SCIP_OKAY;
388 }
#define NULL
Definition: def.h:267
SCIP_RETCODE SCIPcreateBanditExp3(SCIP *scip, SCIP_BANDIT **exp3, SCIP_Real *priorities, SCIP_Real gammaparam, SCIP_Real beta, int nactions, unsigned int initseed)
Definition: bandit_exp3.c:311
public methods for memory management
SCIP_DECL_BANDITUPDATE(SCIPbanditUpdateExp3)
Definition: bandit_exp3.c:144
void SCIPsetBetaExp3(SCIP_BANDIT *exp3, SCIP_Real beta)
Definition: bandit_exp3.c:350
internal methods for bandit algorithms
enum SCIP_Retcode SCIP_RETCODE
Definition: type_retcode.h:63
#define BANDIT_NAME
Definition: bandit_exp3.c:42
SCIP_RETCODE SCIPbanditCreateExp3(BMS_BLKMEM *blkmem, BMS_BUFMEM *bufmem, SCIP_BANDITVTABLE *vtable, SCIP_BANDIT **exp3, SCIP_Real *priorities, SCIP_Real gammaparam, SCIP_Real beta, int nactions, unsigned int initseed)
Definition: bandit_exp3.c:281
SCIP_BANDITDATA * SCIPbanditGetData(SCIP_BANDIT *bandit)
Definition: bandit.c:190
SCIP_DECL_BANDITSELECT(SCIPbanditSelectExp3)
Definition: bandit_exp3.c:87
BMS_BUFMEM * SCIPbuffer(SCIP *scip)
Definition: scip_mem.c:72
#define SCIPerrorMessage
Definition: pub_message.h:64
BMS_BLKMEM * SCIPblkmem(SCIP *scip)
Definition: scip_mem.c:57
SCIP_BANDITVTABLE * SCIPfindBanditvtable(SCIP *scip, const char *name)
Definition: scip_bandit.c:80
SCIP_DECL_BANDITRESET(SCIPbanditResetExp3)
Definition: bandit_exp3.c:219
#define SCIP_CALL(x)
Definition: def.h:380
SCIP_DECL_BANDITFREE(SCIPbanditFreeExp3)
Definition: bandit_exp3.c:67
void SCIPbanditSetData(SCIP_BANDIT *bandit, SCIP_BANDITDATA *banditdata)
Definition: bandit.c:200
#define BMSfreeBlockMemory(mem, ptr)
Definition: memory.h:465
SCIP_RETCODE SCIPincludeBanditvtableExp3(SCIP *scip)
Definition: bandit_exp3.c:377
public data structures and miscellaneous methods
SCIP_Real SCIPgetProbabilityExp3(SCIP_BANDIT *exp3, int action)
Definition: bandit_exp3.c:363
#define BMSallocBlockMemoryArray(mem, ptr, num)
Definition: memory.h:454
SCIP_RETCODE SCIPincludeBanditvtable(SCIP *scip, SCIP_BANDITVTABLE **banditvtable, const char *name, SCIP_DECL_BANDITFREE((*banditfree)), SCIP_DECL_BANDITSELECT((*banditselect)), SCIP_DECL_BANDITUPDATE((*banditupdate)), SCIP_DECL_BANDITRESET((*banditreset)))
Definition: scip_bandit.c:48
#define BMSfreeBlockMemoryArray(mem, ptr, num)
Definition: memory.h:467
public methods for bandit algorithms
SCIP_Real SCIPrandomGetReal(SCIP_RANDNUMGEN *randnumgen, SCIP_Real minrandval, SCIP_Real maxrandval)
Definition: misc.c:10130
public methods for bandit algorithms
struct SCIP_BanditData SCIP_BANDITDATA
Definition: type_bandit.h:56
#define NUMTOL
Definition: bandit_exp3.c:43
public methods for random numbers
internal methods for Exp.3 bandit algorithm
void SCIPsetGammaExp3(SCIP_BANDIT *exp3, SCIP_Real gammaparam)
Definition: bandit_exp3.c:337
public methods for message output
#define SCIP_Real
Definition: def.h:173
int SCIPbanditGetNActions(SCIP_BANDIT *bandit)
Definition: bandit.c:303
SCIP_RANDNUMGEN * SCIPbanditGetRandnumgen(SCIP_BANDIT *bandit)
Definition: bandit.c:293
#define BMSallocBlockMemory(mem, ptr)
Definition: memory.h:451
unsigned int SCIPinitializeRandomSeed(SCIP *scip, unsigned int initialseedvalue)
struct BMS_BlkMem BMS_BLKMEM
Definition: memory.h:437
#define SCIP_ALLOC(x)
Definition: def.h:391
#define EPSZ(x, eps)
Definition: def.h:203
SCIP_RETCODE SCIPbanditCreate(SCIP_BANDIT **bandit, SCIP_BANDITVTABLE *banditvtable, BMS_BLKMEM *blkmem, BMS_BUFMEM *bufmem, SCIP_Real *priorities, int nactions, unsigned int initseed, SCIP_BANDITDATA *banditdata)
Definition: bandit.c:42