this repo has no description
at develop 18 kB view raw
1/* -*- mode: C++; c-basic-offset: 2; indent-tabs-mode: nil -*- */ 2 3/* 4 * Main authors: 5 * Guido Tack <guido.tack@monash.edu> 6 */ 7 8/* This Source Code Form is subject to the terms of the Mozilla Public 9 * License, v. 2.0. If a copy of the MPL was not distributed with this 10 * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ 11 12#include <minizinc/flat_exp.hh> 13 14namespace MiniZinc { 15 16std::vector<Expression*> get_conjuncts(Expression* e) { 17 std::vector<Expression*> conj_stack; 18 std::vector<Expression*> conjuncts; 19 conj_stack.push_back(e); 20 while (!conj_stack.empty()) { 21 Expression* e = conj_stack.back(); 22 conj_stack.pop_back(); 23 if (BinOp* bo = e->dyn_cast<BinOp>()) { 24 if (bo->op() == BOT_AND) { 25 conj_stack.push_back(bo->rhs()); 26 conj_stack.push_back(bo->lhs()); 27 } else { 28 conjuncts.push_back(e); 29 } 30 } else { 31 conjuncts.push_back(e); 32 } 33 } 34 return conjuncts; 35} 36 37void classify_conjunct(Expression* e, IdMap<int>& eq_occurrences, 38 IdMap<std::pair<Expression*, Expression*>>& eq_branches, 39 std::vector<Expression*>& other_branches) { 40 if (BinOp* bo = e->dyn_cast<BinOp>()) { 41 if (bo->op() == BOT_EQ) { 42 if (Id* ident = bo->lhs()->dyn_cast<Id>()) { 43 if (eq_branches.find(ident) == eq_branches.end()) { 44 IdMap<int>::iterator it = eq_occurrences.find(ident); 45 if (it == eq_occurrences.end()) { 46 eq_occurrences.insert(ident, 1); 47 } else { 48 eq_occurrences.get(ident)++; 49 } 50 eq_branches.insert(ident, std::make_pair(bo->rhs(), bo)); 51 return; 52 } 53 } else if (Id* ident = bo->rhs()->dyn_cast<Id>()) { 54 if (eq_branches.find(ident) == eq_branches.end()) { 55 IdMap<int>::iterator it = eq_occurrences.find(ident); 56 if (it == eq_occurrences.end()) { 57 eq_occurrences.insert(ident, 1); 58 } else { 59 eq_occurrences.get(ident)++; 60 } 61 eq_branches.insert(ident, std::make_pair(bo->lhs(), bo)); 62 return; 63 } 64 } 65 } 66 } 67 other_branches.push_back(e); 68} 69 70EE flatten_ite(EnvI& env, Ctx ctx, Expression* e, VarDecl* r, VarDecl* b) { 71 CallStackItem _csi(env, e); 72 ITE* ite = e->cast<ITE>(); 73 74 // The conditions of each branch of the if-then-else 75 std::vector<KeepAlive> conditions; 76 // Whether the right hand side of each branch is defined 77 std::vector<std::vector<KeepAlive>> defined; 78 // The right hand side of each branch 79 std::vector<std::vector<KeepAlive>> branches; 80 // Whether all branches are fixed 81 std::vector<bool> allBranchesPar; 82 83 // Compute bounds of result as union bounds of all branches 84 std::vector<std::vector<IntBounds>> r_bounds_int; 85 std::vector<bool> r_bounds_valid_int; 86 std::vector<std::vector<IntSetVal*>> r_bounds_set; 87 std::vector<bool> r_bounds_valid_set; 88 std::vector<std::vector<FloatBounds>> r_bounds_float; 89 std::vector<bool> r_bounds_valid_float; 90 91 bool allConditionsPar = true; 92 bool allDefined = true; 93 94 // The result variables of each generated conditional 95 std::vector<VarDecl*> results; 96 // The then-expressions of each generated conditional 97 std::vector<std::vector<Expression*>> e_then; 98 // The else-expressions of each generated conditional 99 std::vector<Expression*> e_else; 100 101 bool noOtherBranches = true; 102 if (ite->type() == Type::varbool() && ctx.b == C_ROOT && r == constants().var_true) { 103 // Check if all branches are of the form x1=e1 /\ ... /\ xn=en 104 IdMap<int> eq_occurrences; 105 std::vector<IdMap<std::pair<Expression*, Expression*>>> eq_branches(ite->size() + 1); 106 std::vector<std::vector<Expression*>> other_branches(ite->size() + 1); 107 for (int i = 0; i < ite->size(); i++) { 108 auto conjuncts = get_conjuncts(ite->e_then(i)); 109 for (auto c : conjuncts) { 110 classify_conjunct(c, eq_occurrences, eq_branches[i], other_branches[i]); 111 } 112 noOtherBranches = noOtherBranches && other_branches[i].empty(); 113 } 114 { 115 auto conjuncts = get_conjuncts(ite->e_else()); 116 for (auto c : conjuncts) { 117 classify_conjunct(c, eq_occurrences, eq_branches[ite->size()], other_branches[ite->size()]); 118 } 119 noOtherBranches = noOtherBranches && other_branches[ite->size()].empty(); 120 } 121 for (auto& e : eq_occurrences) { 122 if (e.second >= ite->size()) { 123 // Any identifier that occurs in all or all but one branch gets its own conditional 124 results.push_back(e.first->decl()); 125 e_then.push_back(std::vector<Expression*>()); 126 for (int i = 0; i < ite->size(); i++) { 127 IdMap<std::pair<Expression*, Expression*>>::iterator it = eq_branches[i].find(e.first); 128 if (it == eq_branches[i].end()) { 129 // not found, simply push x=x 130 e_then.back().push_back(e.first); 131 } else { 132 e_then.back().push_back(it->second.first); 133 } 134 } 135 { 136 IdMap<std::pair<Expression*, Expression*>>::iterator it = 137 eq_branches[ite->size()].find(e.first); 138 if (it == eq_branches[ite->size()].end()) { 139 // not found, simply push x=x 140 e_else.push_back(e.first); 141 } else { 142 e_else.push_back(it->second.first); 143 } 144 } 145 } else { 146 // All other identifiers are put in the vector of "other" branches 147 for (int i = 0; i <= ite->size(); i++) { 148 IdMap<std::pair<Expression*, Expression*>>::iterator it = eq_branches[i].find(e.first); 149 if (it != eq_branches[i].end()) { 150 other_branches[i].push_back(it->second.second); 151 noOtherBranches = false; 152 eq_branches[i].remove(e.first); 153 } 154 } 155 } 156 } 157 if (!noOtherBranches) { 158 results.push_back(r); 159 e_then.push_back(std::vector<Expression*>()); 160 for (int i = 0; i < ite->size(); i++) { 161 if (eq_branches[i].size() == 0) { 162 e_then.back().push_back(ite->e_then(i)); 163 } else if (other_branches[i].size() == 0) { 164 e_then.back().push_back(constants().lit_true); 165 } else if (other_branches[i].size() == 1) { 166 e_then.back().push_back(other_branches[i][0]); 167 } else { 168 ArrayLit* al = new ArrayLit(Location().introduce(), other_branches[i]); 169 al->type(Type::varbool(1)); 170 Call* forall = new Call(Location().introduce(), constants().ids.forall, {al}); 171 forall->decl(env.model->matchFn(env, forall, false)); 172 forall->type(forall->decl()->rtype(env, {al}, false)); 173 e_then.back().push_back(forall); 174 } 175 } 176 { 177 if (eq_branches[ite->size()].size() == 0) { 178 e_else.push_back(ite->e_else()); 179 } else if (other_branches[ite->size()].size() == 0) { 180 e_else.push_back(constants().lit_true); 181 } else if (other_branches[ite->size()].size() == 1) { 182 e_else.push_back(other_branches[ite->size()][0]); 183 } else { 184 ArrayLit* al = new ArrayLit(Location().introduce(), other_branches[ite->size()]); 185 al->type(Type::varbool(1)); 186 Call* forall = new Call(Location().introduce(), constants().ids.forall, {al}); 187 forall->decl(env.model->matchFn(env, forall, false)); 188 forall->type(forall->decl()->rtype(env, {al}, false)); 189 e_else.push_back(forall); 190 } 191 } 192 } 193 } else { 194 noOtherBranches = false; 195 results.push_back(r); 196 e_then.push_back(std::vector<Expression*>()); 197 for (int i = 0; i < ite->size(); i++) { 198 e_then.back().push_back(ite->e_then(i)); 199 } 200 e_else.push_back(ite->e_else()); 201 } 202 allBranchesPar.resize(results.size()); 203 r_bounds_valid_int.resize(results.size()); 204 r_bounds_int.resize(results.size()); 205 r_bounds_valid_float.resize(results.size()); 206 r_bounds_float.resize(results.size()); 207 r_bounds_valid_set.resize(results.size()); 208 r_bounds_set.resize(results.size()); 209 defined.resize(results.size()); 210 branches.resize(results.size()); 211 for (unsigned int i = 0; i < results.size(); i++) { 212 allBranchesPar[i] = true; 213 r_bounds_valid_int[i] = true; 214 r_bounds_valid_float[i] = true; 215 r_bounds_valid_set[i] = true; 216 } 217 218 Ctx cmix; 219 cmix.b = C_MIX; 220 cmix.i = C_MIX; 221 222 for (int i = 0; i < ite->size(); i++) { 223 bool cond = true; 224 EE e_if; 225 if (ite->e_if(i)->isa<Call>() && ite->e_if(i)->cast<Call>()->id() == "mzn_in_root_context") { 226 e_if = EE(constants().boollit(ctx.b == C_ROOT), constants().lit_true); 227 } else { 228 e_if = flat_exp(env, cmix, ite->e_if(i), NULL, constants().var_true); 229 } 230 if (e_if.r()->type() == Type::parbool()) { 231 { 232 GCLock lock; 233 cond = eval_bool(env, e_if.r()); 234 } 235 if (cond) { 236 if (allConditionsPar || conditions.size() == 0) { 237 // no var conditions before this one, so we can simply emit 238 // the then branch 239 return flat_exp(env, ctx, ite->e_then(i), r, b); 240 } 241 // had var conditions, so we have to take them into account 242 // and emit new conditional clause 243 Ctx cmix; 244 cmix.b = C_MIX; 245 cmix.i = C_MIX; 246 // add another condition and definedness variable 247 conditions.push_back(constants().lit_true); 248 for (unsigned int j = 0; j < results.size(); j++) { 249 EE ethen = flat_exp(env, cmix, e_then[j][i], NULL, NULL); 250 assert(ethen.b()); 251 defined[j].push_back(ethen.b); 252 allDefined = allDefined && (ethen.b() == constants().lit_true); 253 branches[j].push_back(ethen.r); 254 if (ethen.r()->type().isvar()) { 255 allBranchesPar[j] = false; 256 } 257 } 258 break; 259 } 260 } else { 261 allConditionsPar = false; 262 // add current condition and definedness variable 263 conditions.push_back(e_if.r); 264 265 for (unsigned int j = 0; j < results.size(); j++) { 266 // flatten the then branch 267 EE ethen = flat_exp(env, cmix, e_then[j][i], NULL, NULL); 268 269 assert(ethen.b()); 270 defined[j].push_back(ethen.b); 271 allDefined = allDefined && (ethen.b() == constants().lit_true); 272 branches[j].push_back(ethen.r); 273 if (ethen.r()->type().isvar()) { 274 allBranchesPar[j] = false; 275 } 276 } 277 } 278 // update bounds 279 280 if (cond) { 281 for (unsigned int j = 0; j < results.size(); j++) { 282 if (r_bounds_valid_int[j] && e_then[j][i]->type().isint()) { 283 GCLock lock; 284 IntBounds ib_then = compute_int_bounds(env, e_then[j][i]); 285 if (ib_then.valid) r_bounds_int[j].push_back(ib_then); 286 r_bounds_valid_int[j] = r_bounds_valid_int[j] && ib_then.valid; 287 } else if (r_bounds_valid_set[j] && e_then[j][i]->type().isintset()) { 288 GCLock lock; 289 IntSetVal* isv = compute_intset_bounds(env, e_then[j][i]); 290 if (isv) r_bounds_set[j].push_back(isv); 291 r_bounds_valid_set[j] = r_bounds_valid_set[j] && isv; 292 } else if (r_bounds_valid_float[j] && e_then[j][i]->type().isfloat()) { 293 GCLock lock; 294 FloatBounds fb_then = compute_float_bounds(env, e_then[j][i]); 295 if (fb_then.valid) r_bounds_float[j].push_back(fb_then); 296 r_bounds_valid_float[j] = r_bounds_valid_float[j] && fb_then.valid; 297 } 298 } 299 } 300 } 301 302 if (allConditionsPar) { 303 // no var condition, and all par conditions were false, 304 // so simply emit else branch 305 return flat_exp(env, ctx, ite->e_else(), r, b); 306 } 307 308 for (unsigned int j = 0; j < results.size(); j++) { 309 if (results[j] == NULL) { 310 // need to introduce new result variable 311 GCLock lock; 312 TypeInst* ti = new TypeInst(Location().introduce(), ite->type(), NULL); 313 results[j] = newVarDecl(env, Ctx(), ti, NULL, NULL, NULL); 314 } 315 } 316 317 if (conditions.back()() != constants().lit_true) { 318 // The last condition wasn't fixed to true, we need to look at the else branch 319 conditions.push_back(constants().lit_true); 320 321 for (unsigned int j = 0; j < results.size(); j++) { 322 VarDecl* nr = results[j]; 323 324 // update bounds of result with bounds of else branch 325 326 if (r_bounds_valid_int[j] && e_else[j]->type().isint()) { 327 GCLock lock; 328 IntBounds ib_else = compute_int_bounds(env, e_else[j]); 329 if (ib_else.valid) { 330 r_bounds_int[j].push_back(ib_else); 331 IntVal lb = IntVal::infinity(); 332 IntVal ub = -IntVal::infinity(); 333 for (unsigned int i = 0; i < r_bounds_int[j].size(); i++) { 334 lb = std::min(lb, r_bounds_int[j][i].l); 335 ub = std::max(ub, r_bounds_int[j][i].u); 336 } 337 if (results[j]) { 338 IntBounds orig_r_bounds = compute_int_bounds(env, results[j]->id()); 339 if (orig_r_bounds.valid) { 340 lb = std::max(lb, orig_r_bounds.l); 341 ub = std::min(ub, orig_r_bounds.u); 342 } 343 } 344 SetLit* r_dom = new SetLit(Location().introduce(), IntSetVal::a(lb, ub)); 345 nr->ti()->domain(r_dom); 346 } 347 } else if (r_bounds_valid_set[j] && e_else[j]->type().isintset()) { 348 GCLock lock; 349 IntSetVal* isv_else = compute_intset_bounds(env, e_else[j]); 350 if (isv_else) { 351 IntSetVal* isv = isv_else; 352 for (unsigned int i = 0; i < r_bounds_set[j].size(); i++) { 353 IntSetRanges i0(isv); 354 IntSetRanges i1(r_bounds_set[j][i]); 355 Ranges::Union<IntVal, IntSetRanges, IntSetRanges> u(i0, i1); 356 isv = IntSetVal::ai(u); 357 } 358 if (results[j]) { 359 IntSetVal* orig_r_bounds = compute_intset_bounds(env, results[j]->id()); 360 if (orig_r_bounds) { 361 IntSetRanges i0(isv); 362 IntSetRanges i1(orig_r_bounds); 363 Ranges::Inter<IntVal, IntSetRanges, IntSetRanges> inter(i0, i1); 364 isv = IntSetVal::ai(inter); 365 } 366 } 367 SetLit* r_dom = new SetLit(Location().introduce(), isv); 368 nr->ti()->domain(r_dom); 369 } 370 } else if (r_bounds_valid_float[j] && e_else[j]->type().isfloat()) { 371 GCLock lock; 372 FloatBounds fb_else = compute_float_bounds(env, e_else[j]); 373 if (fb_else.valid) { 374 FloatVal lb = fb_else.l; 375 FloatVal ub = fb_else.u; 376 for (unsigned int i = 0; i < r_bounds_float[j].size(); i++) { 377 lb = std::min(lb, r_bounds_float[j][i].l); 378 ub = std::max(ub, r_bounds_float[j][i].u); 379 } 380 if (results[j]) { 381 FloatBounds orig_r_bounds = compute_float_bounds(env, results[j]->id()); 382 if (orig_r_bounds.valid) { 383 lb = std::max(lb, orig_r_bounds.l); 384 ub = std::min(ub, orig_r_bounds.u); 385 } 386 } 387 BinOp* r_dom = 388 new BinOp(Location().introduce(), FloatLit::a(lb), BOT_DOTDOT, FloatLit::a(ub)); 389 r_dom->type(Type::parfloat(1)); 390 nr->ti()->domain(r_dom); 391 } 392 } 393 394 // flatten else branch 395 EE eelse = flat_exp(env, cmix, e_else[j], NULL, NULL); 396 assert(eelse.b()); 397 defined[j].push_back(eelse.b); 398 allDefined = allDefined && (eelse.b() == constants().lit_true); 399 branches[j].push_back(eelse.r); 400 if (eelse.r()->type().isvar()) { 401 allBranchesPar[j] = false; 402 } 403 } 404 } 405 406 // Create ite predicate calls 407 GCLock lock; 408 ArrayLit* al_cond = new ArrayLit(Location().introduce(), conditions); 409 al_cond->type(Type::varbool(1)); 410 for (unsigned int j = 0; j < results.size(); j++) { 411 ArrayLit* al_branches = new ArrayLit(Location().introduce(), branches[j]); 412 Type branches_t = results[j]->type(); 413 branches_t.dim(1); 414 branches_t.ti(allBranchesPar[j] ? Type::TI_PAR : Type::TI_VAR); 415 al_branches->type(branches_t); 416 Call* ite_pred = new Call(ite->loc().introduce(), ASTString("if_then_else"), 417 {al_cond, al_branches, results[j]->id()}); 418 ite_pred->decl(env.model->matchFn(env, ite_pred, false)); 419 ite_pred->type(Type::varbool()); 420 (void)flat_exp(env, Ctx(), ite_pred, constants().var_true, constants().var_true); 421 } 422 EE ret; 423 if (noOtherBranches) { 424 ret.r = constants().var_true->id(); 425 } else { 426 ret.r = results.back()->id(); 427 } 428 if (allDefined) { 429 bind(env, ctx, b, constants().lit_true); 430 ret.b = constants().lit_true; 431 } else { 432 // Otherwise, constraint linking conditions, b and the definedness variables 433 if (b == NULL) { 434 CallStackItem _csi(env, new StringLit(Location().introduce(), "b")); 435 b = newVarDecl(env, Ctx(), new TypeInst(Location().introduce(), Type::varbool()), NULL, NULL, 436 NULL); 437 } 438 ret.b = b->id(); 439 440 std::vector<Expression*> defined_conjunctions(ite->size() + 1); 441 for (unsigned int i = 0; i < ite->size(); i++) { 442 std::vector<Expression*> def_i; 443 for (unsigned int j = 0; j < defined.size(); j++) { 444 if (defined[j][i]() != constants().lit_true) { 445 def_i.push_back(defined[j][i]()); 446 } 447 } 448 if (def_i.size() == 0) { 449 defined_conjunctions[i] = constants().lit_true; 450 } else if (def_i.size() == 1) { 451 defined_conjunctions[i] = def_i[0]; 452 } else { 453 ArrayLit* al = new ArrayLit(Location().introduce(), def_i); 454 al->type(Type::varbool(1)); 455 Call* forall = new Call(Location().introduce(), constants().ids.forall, {al}); 456 forall->decl(env.model->matchFn(env, forall, false)); 457 forall->type(forall->decl()->rtype(env, {al}, false)); 458 defined_conjunctions[i] = forall; 459 } 460 } 461 ArrayLit* al_defined = new ArrayLit(Location().introduce(), defined_conjunctions); 462 al_defined->type(Type::varbool(1)); 463 Call* ite_defined_pred = new Call(ite->loc().introduce(), ASTString("if_then_else_partiality"), 464 {al_cond, al_defined, b->id()}); 465 ite_defined_pred->decl(env.model->matchFn(env, ite_defined_pred, false)); 466 ite_defined_pred->type(Type::varbool()); 467 (void)flat_exp(env, Ctx(), ite_defined_pred, constants().var_true, constants().var_true); 468 } 469 470 return ret; 471} 472 473} // namespace MiniZinc