Scippy

SCIP

Solving Constraint Integer Programs

bandit_ucb.c
Go to the documentation of this file.
1/* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */
2/* */
3/* This file is part of the program and library */
4/* SCIP --- Solving Constraint Integer Programs */
5/* */
6/* Copyright (c) 2002-2025 Zuse Institute Berlin (ZIB) */
7/* */
8/* Licensed under the Apache License, Version 2.0 (the "License"); */
9/* you may not use this file except in compliance with the License. */
10/* You may obtain a copy of the License at */
11/* */
12/* http://www.apache.org/licenses/LICENSE-2.0 */
13/* */
14/* Unless required by applicable law or agreed to in writing, software */
15/* distributed under the License is distributed on an "AS IS" BASIS, */
16/* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */
17/* See the License for the specific language governing permissions and */
18/* limitations under the License. */
19/* */
20/* You should have received a copy of the Apache-2.0 license */
21/* along with SCIP; see the file LICENSE. If not visit scipopt.org. */
22/* */
23/* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */
24
25/**@file bandit_ucb.c
26 * @ingroup OTHER_CFILES
27 * @brief methods for UCB bandit selection
28 * @author Gregor Hendel
29 */
30
31/*---+----1----+----2----+----3----+----4----+----5----+----6----+----7----+----8----+----9----+----0----+----1----+----2*/
32
#include "scip/bandit.h"
#include "scip/bandit_ucb.h"
#include "scip/pub_bandit.h"
#include "scip/pub_message.h"
#include "scip/pub_misc.h"
#include "scip/pub_misc_sort.h"
#include "scip/scip_bandit.h"
#include "scip/scip_mem.h"
#include "scip/scip_randnumgen.h"

43
/** name of the UCB virtual function table; used to look it up and in error messages */
#define BANDIT_NAME           "ucb"

/** tolerance used both to perturb start priorities and to compare upper confidence bounds */
#define NUMEPS                1e-6

47/*
48 * Data structures
49 */
50
51/** implementation specific data of UCB bandit algorithm */
52struct SCIP_BanditData
53{
54 int nselections; /**< counter for the number of selections */
55 int* counter; /**< array of counters how often every action has been chosen */
56 int* startperm; /**< indices for starting permutation */
57 SCIP_Real* meanscores; /**< array of average scores for the actions */
58 SCIP_Real alpha; /**< parameter to increase confidence width */
59};
60
61
62/*
63 * Local methods
64 */
65
66/** data reset method */
67static
69 BMS_BUFMEM* bufmem, /**< buffer memory */
70 SCIP_BANDIT* ucb, /**< ucb bandit algorithm */
71 SCIP_BANDITDATA* banditdata, /**< UCB bandit data structure */
72 SCIP_Real* priorities, /**< priorities for start permutation, or NULL */
73 int nactions /**< number of actions */
74 )
75{
76 int i;
77 SCIP_RANDNUMGEN* rng;
78
79 assert(bufmem != NULL);
80 assert(ucb != NULL);
81 assert(nactions > 0);
82
83 /* clear counters and scores */
84 BMSclearMemoryArray(banditdata->counter, nactions);
85 BMSclearMemoryArray(banditdata->meanscores, nactions);
86 banditdata->nselections = 0;
87
88 rng = SCIPbanditGetRandnumgen(ucb);
89 assert(rng != NULL);
90
91 /* initialize start permutation as identity */
92 for( i = 0; i < nactions; ++i )
93 banditdata->startperm[i] = i;
94
95 /* prepare the start permutation in decreasing order of priority */
96 if( priorities != NULL )
97 {
98 SCIP_Real* prioritycopy;
99
100 SCIP_ALLOC( BMSduplicateBufferMemoryArray(bufmem, &prioritycopy, priorities, nactions) );
101
102 /* randomly wiggle priorities a little bit to make them unique */
103 for( i = 0; i < nactions; ++i )
104 prioritycopy[i] += SCIPrandomGetReal(rng, -NUMEPS, NUMEPS);
105
106 SCIPsortDownRealInt(prioritycopy, banditdata->startperm, nactions);
107
108 BMSfreeBufferMemoryArray(bufmem, &prioritycopy);
109 }
110 else
111 {
112 /* use a random start permutation */
113 SCIPrandomPermuteIntArray(rng, banditdata->startperm, 0, nactions);
114 }
115
116 return SCIP_OKAY;
117}
118
119
120/*
121 * Callback methods of bandit algorithm
122 */
123
124/** callback to free bandit specific data structures */
125SCIP_DECL_BANDITFREE(SCIPbanditFreeUcb)
126{ /*lint --e{715}*/
127 SCIP_BANDITDATA* banditdata;
128 int nactions;
129 assert(bandit != NULL);
130
131 banditdata = SCIPbanditGetData(bandit);
132 assert(banditdata != NULL);
133 nactions = SCIPbanditGetNActions(bandit);
134
135 BMSfreeBlockMemoryArray(blkmem, &banditdata->counter, nactions);
136 BMSfreeBlockMemoryArray(blkmem, &banditdata->startperm, nactions);
137 BMSfreeBlockMemoryArray(blkmem, &banditdata->meanscores, nactions);
138 BMSfreeBlockMemory(blkmem, &banditdata);
139
140 SCIPbanditSetData(bandit, NULL);
141
142 return SCIP_OKAY;
143}
144
145/** selection callback for bandit selector */
146SCIP_DECL_BANDITSELECT(SCIPbanditSelectUcb)
147{ /*lint --e{715}*/
148 SCIP_BANDITDATA* banditdata;
149 int nactions;
150 int* counter;
151
152 assert(bandit != NULL);
153 assert(selection != NULL);
154
155 banditdata = SCIPbanditGetData(bandit);
156 assert(banditdata != NULL);
157 nactions = SCIPbanditGetNActions(bandit);
158
159 counter = banditdata->counter;
160 /* select the next uninitialized action from the start permutation */
161 if( banditdata->nselections < nactions )
162 {
163 *selection = banditdata->startperm[banditdata->nselections];
164 assert(counter[*selection] == 0);
165 }
166 else
167 {
168 /* select the action with the highest upper confidence bound */
169 SCIP_Real* meanscores;
170 SCIP_Real widthfactor;
171 SCIP_Real maxucb;
172 int i;
174 meanscores = banditdata->meanscores;
175
176 assert(rng != NULL);
177 assert(meanscores != NULL);
178
179 /* compute the confidence width factor that is common for all actions */
180 widthfactor = banditdata->alpha * LOG1P((SCIP_Real)banditdata->nselections);
181 widthfactor = sqrt(widthfactor);
182 maxucb = -1.0;
183
184 /* loop over the actions and determine the maximum upper confidence bound.
185 * The upper confidence bound of an action is the sum of its mean score
186 * plus a confidence term that decreases with increasing number of observations of
187 * this action.
188 */
189 for( i = 0; i < nactions; ++i )
190 {
191 SCIP_Real uppercb;
192 SCIP_Real rootcount;
193 assert(counter[i] > 0);
194
195 /* compute the upper confidence bound for action i */
196 uppercb = meanscores[i];
197 rootcount = sqrt((SCIP_Real)counter[i]);
198 uppercb += widthfactor / rootcount;
199 assert(uppercb > 0);
200
201 /* update maximum, breaking ties uniformly at random */
202 if( EPSGT(uppercb, maxucb, NUMEPS) || (EPSEQ(uppercb, maxucb, NUMEPS) && SCIPrandomGetReal(rng, 0.0, 1.0) >= 0.5) )
203 {
204 maxucb = uppercb;
205 *selection = i;
206 }
207 }
208 }
209
210 assert(*selection >= 0);
211 assert(*selection < nactions);
212
213 return SCIP_OKAY;
214}
215
216/** update callback for bandit algorithm */
217SCIP_DECL_BANDITUPDATE(SCIPbanditUpdateUcb)
218{ /*lint --e{715}*/
219 SCIP_BANDITDATA* banditdata;
220 SCIP_Real delta;
221
222 assert(bandit != NULL);
223
224 banditdata = SCIPbanditGetData(bandit);
225 assert(banditdata != NULL);
226 assert(selection >= 0);
227 assert(selection < SCIPbanditGetNActions(bandit));
228
229 /* increase the mean by the incremental formula: A_n = A_n-1 + 1/n (a_n - A_n-1) */
230 delta = score - banditdata->meanscores[selection];
231 ++banditdata->counter[selection];
232 banditdata->meanscores[selection] += delta / (SCIP_Real)banditdata->counter[selection];
233
234 banditdata->nselections++;
235
236 return SCIP_OKAY;
237}
238
239/** reset callback for bandit algorithm */
240SCIP_DECL_BANDITRESET(SCIPbanditResetUcb)
241{ /*lint --e{715}*/
242 SCIP_BANDITDATA* banditdata;
243 int nactions;
244
245 assert(bufmem != NULL);
246 assert(bandit != NULL);
247
248 banditdata = SCIPbanditGetData(bandit);
249 assert(banditdata != NULL);
250 nactions = SCIPbanditGetNActions(bandit);
251
252 /* call the data reset for the given priorities */
253 SCIP_CALL( dataReset(bufmem, bandit, banditdata, priorities, nactions) );
254
255 return SCIP_OKAY;
256}
257
258/*
259 * bandit algorithm specific interface methods
260 */
261
262/** returns the upper confidence bound of a selected action */
264 SCIP_BANDIT* ucb, /**< UCB bandit algorithm */
265 int action /**< index of the queried action */
266 )
267{
268 SCIP_Real uppercb;
269 SCIP_BANDITDATA* banditdata;
270 int nactions;
271
272 assert(ucb != NULL);
273 banditdata = SCIPbanditGetData(ucb);
274 nactions = SCIPbanditGetNActions(ucb);
275 assert(action < nactions);
276
277 /* since only scores between 0 and 1 are allowed, 1.0 is a sure upper confidence bound */
278 if( banditdata->nselections < nactions )
279 return 1.0;
280
281 /* the bandit algorithm must have picked every action once */
282 assert(banditdata->counter[action] > 0);
283 uppercb = banditdata->meanscores[action];
284
285 uppercb += sqrt(banditdata->alpha * LOG1P((SCIP_Real)banditdata->nselections) / (SCIP_Real)banditdata->counter[action]);
286
287 return uppercb;
288}
289
290/** return start permutation of the UCB bandit algorithm */
292 SCIP_BANDIT* ucb /**< UCB bandit algorithm */
293 )
294{
295 SCIP_BANDITDATA* banditdata = SCIPbanditGetData(ucb);
296
297 assert(banditdata != NULL);
298
299 return banditdata->startperm;
300}
301
302/** internal method to create and reset UCB bandit algorithm */
304 BMS_BLKMEM* blkmem, /**< block memory */
305 BMS_BUFMEM* bufmem, /**< buffer memory */
306 SCIP_BANDITVTABLE* vtable, /**< virtual function table for UCB bandit algorithm */
307 SCIP_BANDIT** ucb, /**< pointer to store bandit algorithm */
308 SCIP_Real* priorities, /**< nonnegative priorities for each action, or NULL if not needed */
309 SCIP_Real alpha, /**< parameter to increase confidence width */
310 int nactions, /**< the positive number of actions for this bandit algorithm */
311 unsigned int initseed /**< initial random seed */
312 )
313{
314 SCIP_BANDITDATA* banditdata;
315
316 if( alpha < 0.0 )
317 {
318 SCIPerrorMessage("UCB requires nonnegative alpha parameter, have %f\n", alpha);
319 return SCIP_INVALIDDATA;
320 }
321
322 SCIP_ALLOC( BMSallocBlockMemory(blkmem, &banditdata) );
323 assert(banditdata != NULL);
324
325 SCIP_ALLOC( BMSallocBlockMemoryArray(blkmem, &banditdata->counter, nactions) );
326 SCIP_ALLOC( BMSallocBlockMemoryArray(blkmem, &banditdata->startperm, nactions) );
327 SCIP_ALLOC( BMSallocBlockMemoryArray(blkmem, &banditdata->meanscores, nactions) );
328
329 banditdata->alpha = alpha;
330
331 SCIP_CALL( SCIPbanditCreate(ucb, vtable, blkmem, bufmem, priorities, nactions, initseed, banditdata) );
332
333 return SCIP_OKAY;
334}
335
336/** create and reset UCB bandit algorithm */
338 SCIP* scip, /**< SCIP data structure */
339 SCIP_BANDIT** ucb, /**< pointer to store bandit algorithm */
340 SCIP_Real* priorities, /**< nonnegative priorities for each action, or NULL if not needed */
341 SCIP_Real alpha, /**< parameter to increase confidence width */
342 int nactions, /**< the positive number of actions for this bandit algorithm */
343 unsigned int initseed /**< initial random number seed */
344 )
345{
346 SCIP_BANDITVTABLE* vtable;
347
349 if( vtable == NULL )
350 {
351 SCIPerrorMessage("Could not find virtual function table for %s bandit algorithm\n", BANDIT_NAME);
352 return SCIP_INVALIDDATA;
353 }
354
356 priorities, alpha, nactions, SCIPinitializeRandomSeed(scip, initseed)) );
357
358 return SCIP_OKAY;
359}
360
361/** include virtual function table for UCB bandit algorithms */
363 SCIP* scip /**< SCIP data structure */
364 )
365{
366 SCIP_BANDITVTABLE* vtable;
367
369 SCIPbanditFreeUcb, SCIPbanditSelectUcb, SCIPbanditUpdateUcb, SCIPbanditResetUcb) );
370 assert(vtable != NULL);
371
372 return SCIP_OKAY;
373}
void SCIPbanditSetData(SCIP_BANDIT *bandit, SCIP_BANDITDATA *banditdata)
Definition: bandit.c:200
SCIP_RETCODE SCIPbanditCreate(SCIP_BANDIT **bandit, SCIP_BANDITVTABLE *banditvtable, BMS_BLKMEM *blkmem, BMS_BUFMEM *bufmem, SCIP_Real *priorities, int nactions, unsigned int initseed, SCIP_BANDITDATA *banditdata)
Definition: bandit.c:42
SCIP_BANDITDATA * SCIPbanditGetData(SCIP_BANDIT *bandit)
Definition: bandit.c:190
internal methods for bandit algorithms
static SCIP_RETCODE dataReset(BMS_BUFMEM *bufmem, SCIP_BANDIT *ucb, SCIP_BANDITDATA *banditdata, SCIP_Real *priorities, int nactions)
Definition: bandit_ucb.c:68
SCIP_DECL_BANDITFREE(SCIPbanditFreeUcb)
Definition: bandit_ucb.c:125
#define NUMEPS
Definition: bandit_ucb.c:45
SCIP_RETCODE SCIPincludeBanditvtableUcb(SCIP *scip)
Definition: bandit_ucb.c:362
SCIP_DECL_BANDITSELECT(SCIPbanditSelectUcb)
Definition: bandit_ucb.c:146
#define BANDIT_NAME
Definition: bandit_ucb.c:44
SCIP_RETCODE SCIPbanditCreateUcb(BMS_BLKMEM *blkmem, BMS_BUFMEM *bufmem, SCIP_BANDITVTABLE *vtable, SCIP_BANDIT **ucb, SCIP_Real *priorities, SCIP_Real alpha, int nactions, unsigned int initseed)
Definition: bandit_ucb.c:303
SCIP_DECL_BANDITUPDATE(SCIPbanditUpdateUcb)
Definition: bandit_ucb.c:217
SCIP_DECL_BANDITRESET(SCIPbanditResetUcb)
Definition: bandit_ucb.c:240
internal methods for UCB bandit algorithm
#define NULL
Definition: def.h:262
#define LOG1P(x)
Definition: def.h:218
#define SCIP_ALLOC(x)
Definition: def.h:380
#define SCIP_Real
Definition: def.h:172
#define EPSEQ(x, y, eps)
Definition: def.h:197
#define EPSGT(x, y, eps)
Definition: def.h:200
#define SCIP_CALL(x)
Definition: def.h:369
void SCIPrandomPermuteIntArray(SCIP_RANDNUMGEN *randnumgen, int *array, int begin, int end)
Definition: misc.c:10150
int * SCIPgetStartPermutationUcb(SCIP_BANDIT *ucb)
Definition: bandit_ucb.c:291
int SCIPbanditGetNActions(SCIP_BANDIT *bandit)
Definition: bandit.c:303
SCIP_RANDNUMGEN * SCIPbanditGetRandnumgen(SCIP_BANDIT *bandit)
Definition: bandit.c:293
SCIP_BANDITVTABLE * SCIPfindBanditvtable(SCIP *scip, const char *name)
Definition: scip_bandit.c:80
SCIP_Real SCIPgetConfidenceBoundUcb(SCIP_BANDIT *ucb, int action)
Definition: bandit_ucb.c:263
SCIP_RETCODE SCIPincludeBanditvtable(SCIP *scip, SCIP_BANDITVTABLE **banditvtable, const char *name, SCIP_DECL_BANDITFREE((*banditfree)), SCIP_DECL_BANDITSELECT((*banditselect)), SCIP_DECL_BANDITUPDATE((*banditupdate)), SCIP_DECL_BANDITRESET((*banditreset)))
Definition: scip_bandit.c:48
SCIP_RETCODE SCIPcreateBanditUcb(SCIP *scip, SCIP_BANDIT **ucb, SCIP_Real *priorities, SCIP_Real alpha, int nactions, unsigned int initseed)
Definition: bandit_ucb.c:337
BMS_BUFMEM * SCIPbuffer(SCIP *scip)
Definition: scip_mem.c:72
SCIP_Real SCIPrandomGetReal(SCIP_RANDNUMGEN *randnumgen, SCIP_Real minrandval, SCIP_Real maxrandval)
Definition: misc.c:10131
unsigned int SCIPinitializeRandomSeed(SCIP *scip, unsigned int initialseedvalue)
void SCIPsortDownRealInt(SCIP_Real *realarray, int *intarray, int len)
#define BMSfreeBlockMemory(mem, ptr)
Definition: memory.h:465
#define BMSduplicateBufferMemoryArray(mem, ptr, source, num)
Definition: memory.h:737
#define BMSallocBlockMemory(mem, ptr)
Definition: memory.h:451
#define BMSfreeBufferMemoryArray(mem, ptr)
Definition: memory.h:742
#define BMSallocBlockMemoryArray(mem, ptr, num)
Definition: memory.h:454
#define BMSfreeBlockMemoryArray(mem, ptr, num)
Definition: memory.h:467
#define BMSclearMemoryArray(ptr, num)
Definition: memory.h:130
struct BMS_BlkMem BMS_BLKMEM
Definition: memory.h:437
BMS_BLKMEM * SCIPblkmem(SCIP *scip)
Definition: scip_mem.c:57
public methods for bandit algorithms
public methods for message output
#define SCIPerrorMessage
Definition: pub_message.h:64
public data structures and miscellaneous methods
methods for sorting joint arrays of various types
public methods for bandit algorithms
public methods for memory management
public methods for random numbers
struct SCIP_BanditData SCIP_BANDITDATA
Definition: type_bandit.h:56
@ SCIP_INVALIDDATA
Definition: type_retcode.h:52
@ SCIP_OKAY
Definition: type_retcode.h:42
enum SCIP_Retcode SCIP_RETCODE
Definition: type_retcode.h:63