Scippy

    SCIP

    Solving Constraint Integer Programs

    bandit_exp3ix.c
    Go to the documentation of this file.
    1/* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */
    2/* */
    3/* This file is part of the program and library */
    4/* SCIP --- Solving Constraint Integer Programs */
    5/* */
    6/* Copyright (c) 2002-2025 Zuse Institute Berlin (ZIB) */
    7/* */
    8/* Licensed under the Apache License, Version 2.0 (the "License"); */
    9/* you may not use this file except in compliance with the License. */
    10/* You may obtain a copy of the License at */
    11/* */
    12/* http://www.apache.org/licenses/LICENSE-2.0 */
    13/* */
    14/* Unless required by applicable law or agreed to in writing, software */
    15/* distributed under the License is distributed on an "AS IS" BASIS, */
    16/* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */
    17/* See the License for the specific language governing permissions and */
    18/* limitations under the License. */
    19/* */
    20/* You should have received a copy of the Apache-2.0 license */
    21/* along with SCIP; see the file LICENSE. If not visit scipopt.org. */
    22/* */
    23/* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */
    24
    25/**@file bandit_exp3ix.c
    26 * @ingroup OTHER_CFILES
    27 * @brief methods for Exp.3-IX bandit selection
    28 * @author Antonia Chmiela
    29 */
    30
    31/*---+----1----+----2----+----3----+----4----+----5----+----6----+----7----+----8----+----9----+----0----+----1----+----2*/
    32
    33#include "scip/bandit.h"
    34#include "scip/bandit_exp3ix.h"
    35#include "scip/pub_bandit.h"
    36#include "scip/pub_message.h"
    37#include "scip/pub_misc.h"
    38#include "scip/scip_bandit.h"
    39#include "scip/scip_mem.h"
    41
    /** name of this bandit algorithm; it is inserted into the error message issued when no virtual function table is found */
    42#define BANDIT_NAME "exp3ix"
    43
    44/*
    45 * Data structures
    46 */
    47
    48/** implementation specific data of Exp.3 bandit algorithm */
    49struct SCIP_BanditData
    50{
    51 SCIP_Real* weights; /**< exponential weight for each arm */
    52 SCIP_Real weightsum; /**< the sum of all weights */
    53 int iter; /**< current iteration counter to compute parameters gamma_t and eta_t */
    54};
    55
    56/*
    57 * Local methods
    58 */
    59
    60/*
    61 * Callback methods of bandit algorithm
    62 */
    63
    64/** callback to free bandit specific data structures */
    65SCIP_DECL_BANDITFREE(SCIPbanditFreeExp3IX)
    66{ /*lint --e{715}*/
    67 SCIP_BANDITDATA* banditdata;
    68 int nactions;
    69 assert(bandit != NULL);
    70
    71 banditdata = SCIPbanditGetData(bandit);
    72 assert(banditdata != NULL);
    73 nactions = SCIPbanditGetNActions(bandit);
    74
    75 BMSfreeBlockMemoryArray(blkmem, &banditdata->weights, nactions);
    76
    77 BMSfreeBlockMemory(blkmem, &banditdata);
    78
    79 SCIPbanditSetData(bandit, NULL);
    80
    81 return SCIP_OKAY;
    82}
    83
    84/** selection callback for bandit selector */
    85SCIP_DECL_BANDITSELECT(SCIPbanditSelectExp3IX)
    86{ /*lint --e{715}*/
    87 SCIP_BANDITDATA* banditdata;
    88 SCIP_RANDNUMGEN* rng;
    89 SCIP_Real* weights;
    90 SCIP_Real weightsum;
    91 int i;
    92 int nactions;
    93 SCIP_Real psum;
    94 SCIP_Real randnr;
    95
    96 assert(bandit != NULL);
    97 assert(selection != NULL);
    98
    99 banditdata = SCIPbanditGetData(bandit);
    100 assert(banditdata != NULL);
    101 rng = SCIPbanditGetRandnumgen(bandit);
    102 assert(rng != NULL);
    103 nactions = SCIPbanditGetNActions(bandit);
    104
    105 /* initialize some local variables to speed up probability computations */
    106 weightsum = banditdata->weightsum;
    107 weights = banditdata->weights;
    108
    109 /* draw a random number between 0 and 1 */
    110 randnr = SCIPrandomGetReal(rng, 0.0, 1.0);
    111
    112 /* loop over probability distribution until rand is reached
    113 * the loop terminates without looking at the last action,
    114 * which is then selected automatically if the target probability
    115 * is not reached earlier
    116 */
    117 psum = 0.0;
    118 for( i = 0; i < nactions - 1; ++i )
    119 {
    120 SCIP_Real prob;
    121
    122 /* compute the probability for arm i */
    123 prob = weights[i] / weightsum;
    124 psum += prob;
    125
    126 /* break and select element if target probability is reached */
    127 if( randnr <= psum )
    128 break;
    129 }
    130
    131 /* select element i, which is the last action in case that the break statement hasn't been reached */
    132 *selection = i;
    133
    134 return SCIP_OKAY;
    135}
    136
    137/** compute gamma_t */
    138static
    140 int nactions, /**< the positive number of actions for this bandit algorithm */
    141 int t /**< current iteration */
    142 )
    143{
    144 return sqrt(log((SCIP_Real)nactions) / (4.0 * (SCIP_Real)t * (SCIP_Real)nactions));
    145}
    146
    147/** update callback for bandit algorithm */
    148SCIP_DECL_BANDITUPDATE(SCIPbanditUpdateExp3IX)
    149{ /*lint --e{715}*/
    150 SCIP_BANDITDATA* banditdata;
    151 SCIP_Real etaparam;
    152 SCIP_Real lossestim;
    153 SCIP_Real prob;
    154 SCIP_Real weightsum;
    155 SCIP_Real newweightsum;
    156 SCIP_Real* weights;
    157 SCIP_Real gammaparam;
    158 int nactions;
    159
    160 assert(bandit != NULL);
    161
    162 banditdata = SCIPbanditGetData(bandit);
    163 assert(banditdata != NULL);
    164 nactions = SCIPbanditGetNActions(bandit);
    165
    166 assert(selection >= 0);
    167 assert(selection < nactions);
    168
    169 weights = banditdata->weights;
    170 weightsum = banditdata->weightsum;
    171 newweightsum = weightsum;
    172 gammaparam = SCIPcomputeGamma(nactions, banditdata->iter);
    173 etaparam = 2.0 * gammaparam;
    174
    175 /* probability of selection */
    176 prob = weights[selection] / weightsum;
    177
    178 /* estimated loss */
    179 lossestim = (1.0 - score) / (prob + gammaparam);
    180 assert(lossestim >= 0);
    181
    182 /* update the observation for the current arm */
    183 newweightsum -= weights[selection];
    184 weights[selection] *= exp(-etaparam * lossestim);
    185 newweightsum += weights[selection];
    186
    187 banditdata->weightsum = newweightsum;
    188
    189 /* increase iteration counter */
    190 banditdata->iter += 1;
    191
    192 return SCIP_OKAY;
    193}
    194
    195/** reset callback for bandit algorithm */
    196SCIP_DECL_BANDITRESET(SCIPbanditResetExp3IX)
    197{ /*lint --e{715}*/
    198 SCIP_BANDITDATA* banditdata;
    199 SCIP_Real* weights;
    200 int nactions;
    201 int i;
    202
    203 assert(bandit != NULL);
    204
    205 banditdata = SCIPbanditGetData(bandit);
    206 assert(banditdata != NULL);
    207 nactions = SCIPbanditGetNActions(bandit);
    208 weights = banditdata->weights;
    209
    210 assert(nactions > 0);
    211
    212 /* initialize all weights with 1.0 */
    213 for( i = 0; i < nactions; ++i )
    214 weights[i] = 1.0;
    215
    216 banditdata->weightsum = (SCIP_Real)nactions;
    217
    218 /* set iteration counter to 1 */
    219 banditdata->iter = 1;
    220
    221 return SCIP_OKAY;
    222}
    223
    224
    225/*
    226 * bandit algorithm specific interface methods
    227 */
    228
    229/** direct bandit creation method for the core where no SCIP pointer is available */
    231 BMS_BLKMEM* blkmem, /**< block memory data structure */
    232 BMS_BUFMEM* bufmem, /**< buffer memory */
    233 SCIP_BANDITVTABLE* vtable, /**< virtual function table for callback functions of Exp.3-IX */
    234 SCIP_BANDIT** exp3ix, /**< pointer to store bandit algorithm */
    235 SCIP_Real* priorities, /**< nonnegative priorities for each action, or NULL if not needed */
    236 int nactions, /**< the positive number of actions for this bandit algorithm */
    237 unsigned int initseed /**< initial random seed */
    238 )
    239{
    240 SCIP_BANDITDATA* banditdata;
    241
    242 SCIP_ALLOC( BMSallocBlockMemory(blkmem, &banditdata) );
    243 assert(banditdata != NULL);
    244
    245 banditdata->iter = 1;
    246
    247 SCIP_ALLOC( BMSallocBlockMemoryArray(blkmem, &banditdata->weights, nactions) );
    248
    249 SCIP_CALL( SCIPbanditCreate(exp3ix, vtable, blkmem, bufmem, priorities, nactions, initseed, banditdata) );
    250
    251 return SCIP_OKAY;
    252}
    253
    254/** creates and resets an Exp.3-IX bandit algorithm using \p scip pointer */
    256 SCIP* scip, /**< SCIP data structure */
    257 SCIP_BANDIT** exp3ix, /**< pointer to store bandit algorithm */
    258 SCIP_Real* priorities, /**< nonnegative priorities for each action, or NULL if not needed */
    259 int nactions, /**< the positive number of actions for this bandit algorithm */
    260 unsigned int initseed /**< initial seed for random number generation */
    261 )
    262{
    263 SCIP_BANDITVTABLE* vtable;
    264
    266 if( vtable == NULL )
    267 {
    268 SCIPerrorMessage("Could not find virtual function table for %s bandit algorithm\n", BANDIT_NAME);
    269 return SCIP_INVALIDDATA;
    270 }
    271
    273 priorities, nactions, SCIPinitializeRandomSeed(scip, initseed)) );
    274
    275 return SCIP_OKAY;
    276}
    277
    278/** returns probability to play an action */
    280 SCIP_BANDIT* exp3ix, /**< bandit algorithm */
    281 int action /**< index of the requested action */
    282 )
    283{
    284 SCIP_BANDITDATA* banditdata = SCIPbanditGetData(exp3ix);
    285
    286 assert(banditdata->weightsum > 0.0);
    287 assert(SCIPbanditGetNActions(exp3ix) > 0);
    288
    289 return banditdata->weights[action] / banditdata->weightsum;
    290}
    291
    292/** include virtual function table for Exp.3-IX bandit algorithms */
    294 SCIP* scip /**< SCIP data structure */
    295 )
    296{
    297 SCIP_BANDITVTABLE* vtable;
    298
    300 SCIPbanditFreeExp3IX, SCIPbanditSelectExp3IX, SCIPbanditUpdateExp3IX, SCIPbanditResetExp3IX) );
    301 assert(vtable != NULL);
    302
    303 return SCIP_OKAY;
    304}
    void SCIPbanditSetData(SCIP_BANDIT *bandit, SCIP_BANDITDATA *banditdata)
    Definition: bandit.c:200
    SCIP_RETCODE SCIPbanditCreate(SCIP_BANDIT **bandit, SCIP_BANDITVTABLE *banditvtable, BMS_BLKMEM *blkmem, BMS_BUFMEM *bufmem, SCIP_Real *priorities, int nactions, unsigned int initseed, SCIP_BANDITDATA *banditdata)
    Definition: bandit.c:42
    SCIP_BANDITDATA * SCIPbanditGetData(SCIP_BANDIT *bandit)
    Definition: bandit.c:190
    internal methods for bandit algorithms
    static SCIP_Real SCIPcomputeGamma(int nactions, int t)
    SCIP_DECL_BANDITFREE(SCIPbanditFreeExp3IX)
    Definition: bandit_exp3ix.c:65
    SCIP_RETCODE SCIPincludeBanditvtableExp3IX(SCIP *scip)
    SCIP_DECL_BANDITSELECT(SCIPbanditSelectExp3IX)
    Definition: bandit_exp3ix.c:85
    SCIP_DECL_BANDITUPDATE(SCIPbanditUpdateExp3IX)
    #define BANDIT_NAME
    Definition: bandit_exp3ix.c:42
    SCIP_RETCODE SCIPbanditCreateExp3IX(BMS_BLKMEM *blkmem, BMS_BUFMEM *bufmem, SCIP_BANDITVTABLE *vtable, SCIP_BANDIT **exp3ix, SCIP_Real *priorities, int nactions, unsigned int initseed)
    SCIP_DECL_BANDITRESET(SCIPbanditResetExp3IX)
    internal methods for Exp.3-IX bandit algorithm
    #define NULL
    Definition: def.h:248
    #define SCIP_ALLOC(x)
    Definition: def.h:366
    #define SCIP_Real
    Definition: def.h:156
    #define SCIP_CALL(x)
    Definition: def.h:355
    int SCIPbanditGetNActions(SCIP_BANDIT *bandit)
    Definition: bandit.c:303
    SCIP_Real SCIPgetProbabilityExp3IX(SCIP_BANDIT *exp3ix, int action)
    SCIP_RANDNUMGEN * SCIPbanditGetRandnumgen(SCIP_BANDIT *bandit)
    Definition: bandit.c:293
    SCIP_BANDITVTABLE * SCIPfindBanditvtable(SCIP *scip, const char *name)
    Definition: scip_bandit.c:80
    SCIP_RETCODE SCIPcreateBanditExp3IX(SCIP *scip, SCIP_BANDIT **exp3ix, SCIP_Real *priorities, int nactions, unsigned int initseed)
    SCIP_RETCODE SCIPincludeBanditvtable(SCIP *scip, SCIP_BANDITVTABLE **banditvtable, const char *name, SCIP_DECL_BANDITFREE((*banditfree)), SCIP_DECL_BANDITSELECT((*banditselect)), SCIP_DECL_BANDITUPDATE((*banditupdate)), SCIP_DECL_BANDITRESET((*banditreset)))
    Definition: scip_bandit.c:48
    BMS_BLKMEM * SCIPblkmem(SCIP *scip)
    Definition: scip_mem.c:57
    BMS_BUFMEM * SCIPbuffer(SCIP *scip)
    Definition: scip_mem.c:72
    SCIP_Real SCIPrandomGetReal(SCIP_RANDNUMGEN *randnumgen, SCIP_Real minrandval, SCIP_Real maxrandval)
    Definition: misc.c:10245
    unsigned int SCIPinitializeRandomSeed(SCIP *scip, unsigned int initialseedvalue)
    #define BMSfreeBlockMemory(mem, ptr)
    Definition: memory.h:465
    #define BMSallocBlockMemory(mem, ptr)
    Definition: memory.h:451
    #define BMSallocBlockMemoryArray(mem, ptr, num)
    Definition: memory.h:454
    #define BMSfreeBlockMemoryArray(mem, ptr, num)
    Definition: memory.h:467
    struct BMS_BlkMem BMS_BLKMEM
    Definition: memory.h:437
    public methods for bandit algorithms
    public methods for message output
    #define SCIPerrorMessage
    Definition: pub_message.h:64
    public data structures and miscellaneous methods
    public methods for bandit algorithms
    public methods for memory management
    public methods for random numbers
    struct SCIP_BanditData SCIP_BANDITDATA
    Definition: type_bandit.h:56
    @ SCIP_INVALIDDATA
    Definition: type_retcode.h:52
    @ SCIP_OKAY
    Definition: type_retcode.h:42
    enum SCIP_Retcode SCIP_RETCODE
    Definition: type_retcode.h:63