Scippy

    SCIP

    Solving Constraint Integer Programs

    bandit_ucb.c
    Go to the documentation of this file.
    1/* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */
    2/* */
    3/* This file is part of the program and library */
    4/* SCIP --- Solving Constraint Integer Programs */
    5/* */
    6/* Copyright (c) 2002-2025 Zuse Institute Berlin (ZIB) */
    7/* */
    8/* Licensed under the Apache License, Version 2.0 (the "License"); */
    9/* you may not use this file except in compliance with the License. */
    10/* You may obtain a copy of the License at */
    11/* */
    12/* http://www.apache.org/licenses/LICENSE-2.0 */
    13/* */
    14/* Unless required by applicable law or agreed to in writing, software */
    15/* distributed under the License is distributed on an "AS IS" BASIS, */
    16/* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */
    17/* See the License for the specific language governing permissions and */
    18/* limitations under the License. */
    19/* */
    20/* You should have received a copy of the Apache-2.0 license */
    21/* along with SCIP; see the file LICENSE. If not visit scipopt.org. */
    22/* */
    23/* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */
    24
    25/**@file bandit_ucb.c
    26 * @ingroup OTHER_CFILES
    27 * @brief methods for UCB bandit selection
    28 * @author Gregor Hendel
    29 */
    30
    31/*---+----1----+----2----+----3----+----4----+----5----+----6----+----7----+----8----+----9----+----0----+----1----+----2*/
    32
    33#include "scip/bandit.h"
    34#include "scip/bandit_ucb.h"
    35#include "scip/pub_bandit.h"
    36#include "scip/pub_message.h"
    37#include "scip/pub_misc.h"
    38#include "scip/pub_misc_sort.h"
    39#include "scip/scip_bandit.h"
    40#include "scip/scip_mem.h"
    42
    43
    #define BANDIT_NAME "ucb"    /**< name used to register and look up the UCB virtual function table */
    #define NUMEPS      1e-6     /**< tolerance for comparing confidence bounds; also bounds the random priority wiggle */
    46
    47/*
    48 * Data structures
    49 */
    50
    51/** implementation specific data of UCB bandit algorithm */
    52struct SCIP_BanditData
    53{
    54 int nselections; /**< counter for the number of selections */
    55 int* counter; /**< array of counters how often every action has been chosen */
    56 int* startperm; /**< indices for starting permutation */
    57 SCIP_Real* meanscores; /**< array of average scores for the actions */
    58 SCIP_Real alpha; /**< parameter to increase confidence width */
    59};
    60
    61
    62/*
    63 * Local methods
    64 */
    65
    66/** data reset method */
    67static
    69 BMS_BUFMEM* bufmem, /**< buffer memory */
    70 SCIP_BANDIT* ucb, /**< ucb bandit algorithm */
    71 SCIP_BANDITDATA* banditdata, /**< UCB bandit data structure */
    72 SCIP_Real* priorities, /**< priorities for start permutation, or NULL */
    73 int nactions /**< number of actions */
    74 )
    75{
    76 int i;
    77 SCIP_RANDNUMGEN* rng;
    78
    79 assert(bufmem != NULL);
    80 assert(ucb != NULL);
    81 assert(nactions > 0);
    82
    83 /* clear counters and scores */
    84 BMSclearMemoryArray(banditdata->counter, nactions);
    85 BMSclearMemoryArray(banditdata->meanscores, nactions);
    86 banditdata->nselections = 0;
    87
    88 rng = SCIPbanditGetRandnumgen(ucb);
    89 assert(rng != NULL);
    90
    91 /* initialize start permutation as identity */
    92 for( i = 0; i < nactions; ++i )
    93 banditdata->startperm[i] = i;
    94
    95 /* prepare the start permutation in decreasing order of priority */
    96 if( priorities != NULL )
    97 {
    98 SCIP_Real* prioritycopy;
    99
    100 SCIP_ALLOC( BMSduplicateBufferMemoryArray(bufmem, &prioritycopy, priorities, nactions) );
    101
    102 /* randomly wiggle priorities a little bit to make them unique */
    103 for( i = 0; i < nactions; ++i )
    104 prioritycopy[i] += SCIPrandomGetReal(rng, -NUMEPS, NUMEPS);
    105
    106 SCIPsortDownRealInt(prioritycopy, banditdata->startperm, nactions);
    107
    108 BMSfreeBufferMemoryArray(bufmem, &prioritycopy);
    109 }
    110 else
    111 {
    112 /* use a random start permutation */
    113 SCIPrandomPermuteIntArray(rng, banditdata->startperm, 0, nactions);
    114 }
    115
    116 return SCIP_OKAY;
    117}
    118
    119
    120/*
    121 * Callback methods of bandit algorithm
    122 */
    123
    124/** callback to free bandit specific data structures */
    125SCIP_DECL_BANDITFREE(SCIPbanditFreeUcb)
    126{ /*lint --e{715}*/
    127 SCIP_BANDITDATA* banditdata;
    128 int nactions;
    129 assert(bandit != NULL);
    130
    131 banditdata = SCIPbanditGetData(bandit);
    132 assert(banditdata != NULL);
    133 nactions = SCIPbanditGetNActions(bandit);
    134
    135 BMSfreeBlockMemoryArray(blkmem, &banditdata->counter, nactions);
    136 BMSfreeBlockMemoryArray(blkmem, &banditdata->startperm, nactions);
    137 BMSfreeBlockMemoryArray(blkmem, &banditdata->meanscores, nactions);
    138 BMSfreeBlockMemory(blkmem, &banditdata);
    139
    140 SCIPbanditSetData(bandit, NULL);
    141
    142 return SCIP_OKAY;
    143}
    144
    145/** selection callback for bandit selector */
    146SCIP_DECL_BANDITSELECT(SCIPbanditSelectUcb)
    147{ /*lint --e{715}*/
    148 SCIP_BANDITDATA* banditdata;
    149 int nactions;
    150 int* counter;
    151
    152 assert(bandit != NULL);
    153 assert(selection != NULL);
    154
    155 banditdata = SCIPbanditGetData(bandit);
    156 assert(banditdata != NULL);
    157 nactions = SCIPbanditGetNActions(bandit);
    158
    159 counter = banditdata->counter;
    160 /* select the next uninitialized action from the start permutation */
    161 if( banditdata->nselections < nactions )
    162 {
    163 *selection = banditdata->startperm[banditdata->nselections];
    164 assert(counter[*selection] == 0);
    165 }
    166 else
    167 {
    168 /* select the action with the highest upper confidence bound */
    169 SCIP_Real* meanscores;
    170 SCIP_Real widthfactor;
    171 SCIP_Real maxucb;
    172 int i;
    174 meanscores = banditdata->meanscores;
    175
    176 assert(rng != NULL);
    177 assert(meanscores != NULL);
    178
    179 /* compute the confidence width factor that is common for all actions */
    180 widthfactor = banditdata->alpha * LOG1P((SCIP_Real)banditdata->nselections);
    181 widthfactor = sqrt(widthfactor);
    182 maxucb = -1.0;
    183
    184 /* loop over the actions and determine the maximum upper confidence bound.
    185 * The upper confidence bound of an action is the sum of its mean score
    186 * plus a confidence term that decreases with increasing number of observations of
    187 * this action.
    188 */
    189 for( i = 0; i < nactions; ++i )
    190 {
    191 SCIP_Real uppercb;
    192 SCIP_Real rootcount;
    193 assert(counter[i] > 0);
    194
    195 /* compute the upper confidence bound for action i */
    196 uppercb = meanscores[i];
    197 rootcount = sqrt((SCIP_Real)counter[i]);
    198 uppercb += widthfactor / rootcount;
    199 assert(uppercb > 0);
    200
    201 /* update maximum, breaking ties uniformly at random */
    202 if( EPSGT(uppercb, maxucb, NUMEPS) || (EPSEQ(uppercb, maxucb, NUMEPS) && SCIPrandomGetReal(rng, 0.0, 1.0) >= 0.5) )
    203 {
    204 maxucb = uppercb;
    205 *selection = i;
    206 }
    207 }
    208 }
    209
    210 assert(*selection >= 0);
    211 assert(*selection < nactions);
    212
    213 return SCIP_OKAY;
    214}
    215
    216/** update callback for bandit algorithm */
    217SCIP_DECL_BANDITUPDATE(SCIPbanditUpdateUcb)
    218{ /*lint --e{715}*/
    219 SCIP_BANDITDATA* banditdata;
    220 SCIP_Real delta;
    221
    222 assert(bandit != NULL);
    223
    224 banditdata = SCIPbanditGetData(bandit);
    225 assert(banditdata != NULL);
    226 assert(selection >= 0);
    227 assert(selection < SCIPbanditGetNActions(bandit));
    228
    229 /* increase the mean by the incremental formula: A_n = A_n-1 + 1/n (a_n - A_n-1) */
    230 delta = score - banditdata->meanscores[selection];
    231 ++banditdata->counter[selection];
    232 banditdata->meanscores[selection] += delta / (SCIP_Real)banditdata->counter[selection];
    233
    234 banditdata->nselections++;
    235
    236 return SCIP_OKAY;
    237}
    238
    239/** reset callback for bandit algorithm */
    240SCIP_DECL_BANDITRESET(SCIPbanditResetUcb)
    241{ /*lint --e{715}*/
    242 SCIP_BANDITDATA* banditdata;
    243 int nactions;
    244
    245 assert(bufmem != NULL);
    246 assert(bandit != NULL);
    247
    248 banditdata = SCIPbanditGetData(bandit);
    249 assert(banditdata != NULL);
    250 nactions = SCIPbanditGetNActions(bandit);
    251
    252 /* call the data reset for the given priorities */
    253 SCIP_CALL( dataReset(bufmem, bandit, banditdata, priorities, nactions) );
    254
    255 return SCIP_OKAY;
    256}
    257
    258/*
    259 * bandit algorithm specific interface methods
    260 */
    261
    262/** returns the upper confidence bound of a selected action */
    264 SCIP_BANDIT* ucb, /**< UCB bandit algorithm */
    265 int action /**< index of the queried action */
    266 )
    267{
    268 SCIP_Real uppercb;
    269 SCIP_BANDITDATA* banditdata;
    270 int nactions;
    271
    272 assert(ucb != NULL);
    273 banditdata = SCIPbanditGetData(ucb);
    274 nactions = SCIPbanditGetNActions(ucb);
    275 assert(action < nactions);
    276
    277 /* since only scores between 0 and 1 are allowed, 1.0 is a sure upper confidence bound */
    278 if( banditdata->nselections < nactions )
    279 return 1.0;
    280
    281 /* the bandit algorithm must have picked every action once */
    282 assert(banditdata->counter[action] > 0);
    283 uppercb = banditdata->meanscores[action];
    284
    285 uppercb += sqrt(banditdata->alpha * LOG1P((SCIP_Real)banditdata->nselections) / (SCIP_Real)banditdata->counter[action]);
    286
    287 return uppercb;
    288}
    289
    290/** return start permutation of the UCB bandit algorithm */
    292 SCIP_BANDIT* ucb /**< UCB bandit algorithm */
    293 )
    294{
    295 SCIP_BANDITDATA* banditdata = SCIPbanditGetData(ucb);
    296
    297 assert(banditdata != NULL);
    298
    299 return banditdata->startperm;
    300}
    301
    302/** internal method to create and reset UCB bandit algorithm */
    304 BMS_BLKMEM* blkmem, /**< block memory */
    305 BMS_BUFMEM* bufmem, /**< buffer memory */
    306 SCIP_BANDITVTABLE* vtable, /**< virtual function table for UCB bandit algorithm */
    307 SCIP_BANDIT** ucb, /**< pointer to store bandit algorithm */
    308 SCIP_Real* priorities, /**< nonnegative priorities for each action, or NULL if not needed */
    309 SCIP_Real alpha, /**< parameter to increase confidence width */
    310 int nactions, /**< the positive number of actions for this bandit algorithm */
    311 unsigned int initseed /**< initial random seed */
    312 )
    313{
    314 SCIP_BANDITDATA* banditdata;
    315
    316 if( alpha < 0.0 )
    317 {
    318 SCIPerrorMessage("UCB requires nonnegative alpha parameter, have %f\n", alpha);
    319 return SCIP_INVALIDDATA;
    320 }
    321
    322 SCIP_ALLOC( BMSallocBlockMemory(blkmem, &banditdata) );
    323 assert(banditdata != NULL);
    324
    325 SCIP_ALLOC( BMSallocBlockMemoryArray(blkmem, &banditdata->counter, nactions) );
    326 SCIP_ALLOC( BMSallocBlockMemoryArray(blkmem, &banditdata->startperm, nactions) );
    327 SCIP_ALLOC( BMSallocBlockMemoryArray(blkmem, &banditdata->meanscores, nactions) );
    328
    329 banditdata->alpha = alpha;
    330
    331 SCIP_CALL( SCIPbanditCreate(ucb, vtable, blkmem, bufmem, priorities, nactions, initseed, banditdata) );
    332
    333 return SCIP_OKAY;
    334}
    335
    336/** create and reset UCB bandit algorithm */
    338 SCIP* scip, /**< SCIP data structure */
    339 SCIP_BANDIT** ucb, /**< pointer to store bandit algorithm */
    340 SCIP_Real* priorities, /**< nonnegative priorities for each action, or NULL if not needed */
    341 SCIP_Real alpha, /**< parameter to increase confidence width */
    342 int nactions, /**< the positive number of actions for this bandit algorithm */
    343 unsigned int initseed /**< initial random number seed */
    344 )
    345{
    346 SCIP_BANDITVTABLE* vtable;
    347
    349 if( vtable == NULL )
    350 {
    351 SCIPerrorMessage("Could not find virtual function table for %s bandit algorithm\n", BANDIT_NAME);
    352 return SCIP_INVALIDDATA;
    353 }
    354
    356 priorities, alpha, nactions, SCIPinitializeRandomSeed(scip, initseed)) );
    357
    358 return SCIP_OKAY;
    359}
    360
    361/** include virtual function table for UCB bandit algorithms */
    363 SCIP* scip /**< SCIP data structure */
    364 )
    365{
    366 SCIP_BANDITVTABLE* vtable;
    367
    369 SCIPbanditFreeUcb, SCIPbanditSelectUcb, SCIPbanditUpdateUcb, SCIPbanditResetUcb) );
    370 assert(vtable != NULL);
    371
    372 return SCIP_OKAY;
    373}
    void SCIPbanditSetData(SCIP_BANDIT *bandit, SCIP_BANDITDATA *banditdata)
    Definition: bandit.c:200
    SCIP_RETCODE SCIPbanditCreate(SCIP_BANDIT **bandit, SCIP_BANDITVTABLE *banditvtable, BMS_BLKMEM *blkmem, BMS_BUFMEM *bufmem, SCIP_Real *priorities, int nactions, unsigned int initseed, SCIP_BANDITDATA *banditdata)
    Definition: bandit.c:42
    SCIP_BANDITDATA * SCIPbanditGetData(SCIP_BANDIT *bandit)
    Definition: bandit.c:190
    internal methods for bandit algorithms
    static SCIP_RETCODE dataReset(BMS_BUFMEM *bufmem, SCIP_BANDIT *ucb, SCIP_BANDITDATA *banditdata, SCIP_Real *priorities, int nactions)
    Definition: bandit_ucb.c:68
    SCIP_DECL_BANDITFREE(SCIPbanditFreeUcb)
    Definition: bandit_ucb.c:125
    #define NUMEPS
    Definition: bandit_ucb.c:45
    SCIP_RETCODE SCIPincludeBanditvtableUcb(SCIP *scip)
    Definition: bandit_ucb.c:362
    SCIP_DECL_BANDITSELECT(SCIPbanditSelectUcb)
    Definition: bandit_ucb.c:146
    #define BANDIT_NAME
    Definition: bandit_ucb.c:44
    SCIP_RETCODE SCIPbanditCreateUcb(BMS_BLKMEM *blkmem, BMS_BUFMEM *bufmem, SCIP_BANDITVTABLE *vtable, SCIP_BANDIT **ucb, SCIP_Real *priorities, SCIP_Real alpha, int nactions, unsigned int initseed)
    Definition: bandit_ucb.c:303
    SCIP_DECL_BANDITUPDATE(SCIPbanditUpdateUcb)
    Definition: bandit_ucb.c:217
    SCIP_DECL_BANDITRESET(SCIPbanditResetUcb)
    Definition: bandit_ucb.c:240
    internal methods for UCB bandit algorithm
    #define NULL
    Definition: def.h:248
    #define LOG1P(x)
    Definition: def.h:204
    #define SCIP_ALLOC(x)
    Definition: def.h:366
    #define SCIP_Real
    Definition: def.h:156
    #define EPSEQ(x, y, eps)
    Definition: def.h:183
    #define EPSGT(x, y, eps)
    Definition: def.h:186
    #define SCIP_CALL(x)
    Definition: def.h:355
    void SCIPrandomPermuteIntArray(SCIP_RANDNUMGEN *randnumgen, int *array, int begin, int end)
    Definition: misc.c:10264
    int * SCIPgetStartPermutationUcb(SCIP_BANDIT *ucb)
    Definition: bandit_ucb.c:291
    int SCIPbanditGetNActions(SCIP_BANDIT *bandit)
    Definition: bandit.c:303
    SCIP_RANDNUMGEN * SCIPbanditGetRandnumgen(SCIP_BANDIT *bandit)
    Definition: bandit.c:293
    SCIP_BANDITVTABLE * SCIPfindBanditvtable(SCIP *scip, const char *name)
    Definition: scip_bandit.c:80
    SCIP_Real SCIPgetConfidenceBoundUcb(SCIP_BANDIT *ucb, int action)
    Definition: bandit_ucb.c:263
    SCIP_RETCODE SCIPincludeBanditvtable(SCIP *scip, SCIP_BANDITVTABLE **banditvtable, const char *name, SCIP_DECL_BANDITFREE((*banditfree)), SCIP_DECL_BANDITSELECT((*banditselect)), SCIP_DECL_BANDITUPDATE((*banditupdate)), SCIP_DECL_BANDITRESET((*banditreset)))
    Definition: scip_bandit.c:48
    SCIP_RETCODE SCIPcreateBanditUcb(SCIP *scip, SCIP_BANDIT **ucb, SCIP_Real *priorities, SCIP_Real alpha, int nactions, unsigned int initseed)
    Definition: bandit_ucb.c:337
    BMS_BLKMEM * SCIPblkmem(SCIP *scip)
    Definition: scip_mem.c:57
    BMS_BUFMEM * SCIPbuffer(SCIP *scip)
    Definition: scip_mem.c:72
    SCIP_Real SCIPrandomGetReal(SCIP_RANDNUMGEN *randnumgen, SCIP_Real minrandval, SCIP_Real maxrandval)
    Definition: misc.c:10245
    unsigned int SCIPinitializeRandomSeed(SCIP *scip, unsigned int initialseedvalue)
    void SCIPsortDownRealInt(SCIP_Real *realarray, int *intarray, int len)
    #define BMSfreeBlockMemory(mem, ptr)
    Definition: memory.h:465
    #define BMSduplicateBufferMemoryArray(mem, ptr, source, num)
    Definition: memory.h:737
    #define BMSallocBlockMemory(mem, ptr)
    Definition: memory.h:451
    #define BMSfreeBufferMemoryArray(mem, ptr)
    Definition: memory.h:742
    #define BMSallocBlockMemoryArray(mem, ptr, num)
    Definition: memory.h:454
    #define BMSfreeBlockMemoryArray(mem, ptr, num)
    Definition: memory.h:467
    #define BMSclearMemoryArray(ptr, num)
    Definition: memory.h:130
    struct BMS_BlkMem BMS_BLKMEM
    Definition: memory.h:437
    public methods for bandit algorithms
    public methods for message output
    #define SCIPerrorMessage
    Definition: pub_message.h:64
    public data structures and miscellaneous methods
    methods for sorting joint arrays of various types
    public methods for bandit algorithms
    public methods for memory management
    public methods for random numbers
    struct SCIP_BanditData SCIP_BANDITDATA
    Definition: type_bandit.h:56
    @ SCIP_INVALIDDATA
    Definition: type_retcode.h:52
    @ SCIP_OKAY
    Definition: type_retcode.h:42
    enum SCIP_Retcode SCIP_RETCODE
    Definition: type_retcode.h:63