this repo has no description
at develop 20 kB view raw
1/* -*- mode: C++; c-basic-offset: 2; indent-tabs-mode: nil -*- */ 2 3/* 4 * Main authors: 5 * Guido Tack <guido.tack@monash.edu> 6 */ 7 8/* This Source Code Form is subject to the terms of the Mozilla Public 9 * License, v. 2.0. If a copy of the MPL was not distributed with this 10 * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ 11 12#include <minizinc/flat_exp.hh> 13 14namespace MiniZinc { 15 16std::vector<Expression*> get_conjuncts(Expression* start) { 17 std::vector<Expression*> conj_stack; 18 std::vector<Expression*> conjuncts; 19 conj_stack.push_back(start); 20 while (!conj_stack.empty()) { 21 Expression* e = conj_stack.back(); 22 conj_stack.pop_back(); 23 if (auto* bo = e->dynamicCast<BinOp>()) { 24 if (bo->op() == BOT_AND) { 25 conj_stack.push_back(bo->rhs()); 26 conj_stack.push_back(bo->lhs()); 27 } else { 28 conjuncts.push_back(e); 29 } 30 } else { 31 conjuncts.push_back(e); 32 } 33 } 34 return conjuncts; 35} 36 37void classify_conjunct(Expression* e, IdMap<int>& eq_occurrences, 38 IdMap<std::pair<Expression*, Expression*>>& eq_branches, 39 std::vector<Expression*>& other_branches) { 40 if (auto* bo = e->dynamicCast<BinOp>()) { 41 if (bo->op() == BOT_EQ) { 42 if (Id* ident = bo->lhs()->dynamicCast<Id>()) { 43 if (eq_branches.find(ident) == eq_branches.end()) { 44 auto it = eq_occurrences.find(ident); 45 if (it == eq_occurrences.end()) { 46 eq_occurrences.insert(ident, 1); 47 } else { 48 eq_occurrences.get(ident)++; 49 } 50 eq_branches.insert(ident, std::make_pair(bo->rhs(), bo)); 51 return; 52 } 53 } else if (Id* ident = bo->rhs()->dynamicCast<Id>()) { 54 if (eq_branches.find(ident) == eq_branches.end()) { 55 auto it = eq_occurrences.find(ident); 56 if (it == eq_occurrences.end()) { 57 eq_occurrences.insert(ident, 1); 58 } else { 59 eq_occurrences.get(ident)++; 60 } 61 eq_branches.insert(ident, std::make_pair(bo->lhs(), bo)); 62 return; 63 } 64 } 65 } 66 } 67 other_branches.push_back(e); 68} 69 70EE flatten_ite(EnvI& env, const Ctx& ctx, Expression* e, VarDecl* r, VarDecl* b) { 71 CallStackItem _csi(env, e); 72 ITE* ite = e->cast<ITE>(); 73 74 // The conditions of each branch of the if-then-else 75 std::vector<KeepAlive> conditions; 76 // Whether the right hand side of each branch is defined 77 std::vector<std::vector<KeepAlive>> defined; 78 // The right hand side of each branch 79 std::vector<std::vector<KeepAlive>> branches; 80 // Whether all branches are fixed 81 std::vector<bool> allBranchesPar; 82 83 // Compute bounds of result as union bounds of all branches 84 std::vector<std::vector<IntBounds>> r_bounds_int; 85 std::vector<bool> r_bounds_valid_int; 86 std::vector<std::vector<IntSetVal*>> r_bounds_set; 87 std::vector<bool> r_bounds_valid_set; 88 std::vector<std::vector<FloatBounds>> r_bounds_float; 89 std::vector<bool> r_bounds_valid_float; 90 91 bool allConditionsPar = true; 92 bool allDefined = true; 93 94 // The result variables of each generated conditional 95 std::vector<VarDecl*> results; 96 // The then-expressions of each generated conditional 97 std::vector<std::vector<KeepAlive>> e_then; 98 // The else-expressions of each generated conditional 99 std::vector<KeepAlive> e_else; 100 101 bool noOtherBranches = true; 102 if (ite->type() == Type::varbool() && ctx.b == C_ROOT && r == constants().varTrue) { 103 // Check if all branches are of the form x1=e1 /\ ... /\ xn=en 104 IdMap<int> eq_occurrences; 105 std::vector<IdMap<std::pair<Expression*, Expression*>>> eq_branches(ite->size() + 1); 106 std::vector<std::vector<Expression*>> other_branches(ite->size() + 1); 107 for (int i = 0; i < ite->size(); i++) { 108 auto conjuncts = get_conjuncts(ite->thenExpr(i)); 109 for (auto* c : conjuncts) { 110 classify_conjunct(c, eq_occurrences, eq_branches[i], other_branches[i]); 111 } 112 noOtherBranches = noOtherBranches && other_branches[i].empty(); 113 } 114 { 115 auto conjuncts = get_conjuncts(ite->elseExpr()); 116 for (auto* c : conjuncts) { 117 classify_conjunct(c, eq_occurrences, eq_branches[ite->size()], other_branches[ite->size()]); 118 } 119 noOtherBranches = noOtherBranches && other_branches[ite->size()].empty(); 120 } 121 for (auto& eq : eq_occurrences) { 122 if (eq.second >= ite->size()) { 123 // Any identifier that occurs in all or all but one branch gets its own conditional 124 results.push_back(eq.first->decl()); 125 e_then.emplace_back(); 126 for (int i = 0; i < ite->size(); i++) { 127 auto it = eq_branches[i].find(eq.first); 128 if (it == eq_branches[i].end()) { 129 // not found, simply push x=x 130 e_then.back().push_back(eq.first); 131 } else { 132 e_then.back().push_back(it->second.first); 133 } 134 } 135 { 136 auto it = eq_branches[ite->size()].find(eq.first); 137 if (it == eq_branches[ite->size()].end()) { 138 // not found, simply push x=x 139 e_else.emplace_back(eq.first); 140 } else { 141 e_else.emplace_back(it->second.first); 142 } 143 } 144 } else { 145 // All other identifiers are put in the vector of "other" branches 146 for (int i = 0; i <= ite->size(); i++) { 147 auto it = eq_branches[i].find(eq.first); 148 if (it != eq_branches[i].end()) { 149 other_branches[i].push_back(it->second.second); 150 noOtherBranches = false; 151 eq_branches[i].remove(eq.first); 152 } 153 } 154 } 155 } 156 if (!noOtherBranches) { 157 results.push_back(r); 158 e_then.emplace_back(); 159 for (int i = 0; i < ite->size(); i++) { 160 if (eq_branches[i].size() == 0) { 161 e_then.back().push_back(ite->thenExpr(i)); 162 } else if (other_branches[i].empty()) { 163 e_then.back().push_back(constants().literalTrue); 164 } else if (other_branches[i].size() == 1) { 165 e_then.back().push_back(other_branches[i][0]); 166 } else { 167 GCLock lock; 168 auto* al = new ArrayLit(Location().introduce(), other_branches[i]); 169 al->type(Type::varbool(1)); 170 Call* forall = new Call(Location().introduce(), constants().ids.forall, {al}); 171 forall->decl(env.model->matchFn(env, forall, false)); 172 forall->type(forall->decl()->rtype(env, {al}, false)); 173 e_then.back().push_back(forall); 174 } 175 } 176 { 177 if (eq_branches[ite->size()].size() == 0) { 178 e_else.emplace_back(ite->elseExpr()); 179 } else if (other_branches[ite->size()].empty()) { 180 e_else.emplace_back(constants().literalTrue); 181 } else if (other_branches[ite->size()].size() == 1) { 182 e_else.emplace_back(other_branches[ite->size()][0]); 183 } else { 184 GCLock lock; 185 auto* al = new ArrayLit(Location().introduce(), other_branches[ite->size()]); 186 al->type(Type::varbool(1)); 187 Call* forall = new Call(Location().introduce(), constants().ids.forall, {al}); 188 forall->decl(env.model->matchFn(env, forall, false)); 189 forall->type(forall->decl()->rtype(env, {al}, false)); 190 e_else.emplace_back(forall); 191 } 192 } 193 } 194 } else { 195 noOtherBranches = false; 196 results.push_back(r); 197 e_then.emplace_back(); 198 for (int i = 0; i < ite->size(); i++) { 199 e_then.back().push_back(ite->thenExpr(i)); 200 } 201 e_else.emplace_back(ite->elseExpr()); 202 } 203 allBranchesPar.resize(results.size()); 204 r_bounds_valid_int.resize(results.size()); 205 r_bounds_int.resize(results.size()); 206 r_bounds_valid_float.resize(results.size()); 207 r_bounds_float.resize(results.size()); 208 r_bounds_valid_set.resize(results.size()); 209 r_bounds_set.resize(results.size()); 210 defined.resize(results.size()); 211 branches.resize(results.size()); 212 for (unsigned int i = 0; i < results.size(); i++) { 213 allBranchesPar[i] = true; 214 r_bounds_valid_int[i] = true; 215 r_bounds_valid_float[i] = true; 216 r_bounds_valid_set[i] = true; 217 } 218 219 Ctx cmix; 220 cmix.b = C_MIX; 221 cmix.i = C_MIX; 222 cmix.neg = ctx.neg; 223 224 bool foundTrueBranch = false; 225 for (int i = 0; i < ite->size() && !foundTrueBranch; i++) { 226 bool cond = true; 227 EE e_if; 228 if (ite->ifExpr(i)->isa<Call>() && 229 ite->ifExpr(i)->cast<Call>()->id() == "mzn_in_root_context") { 230 e_if = EE(constants().boollit(ctx.b == C_ROOT), constants().literalTrue); 231 } else { 232 Ctx cmix_not_negated; 233 cmix_not_negated.b = C_MIX; 234 cmix_not_negated.i = C_MIX; 235 e_if = flat_exp(env, cmix_not_negated, ite->ifExpr(i), nullptr, constants().varTrue); 236 } 237 if (e_if.r()->type() == Type::parbool()) { 238 { 239 GCLock lock; 240 cond = eval_bool(env, e_if.r()); 241 } 242 if (cond) { 243 if (allConditionsPar) { 244 // no var conditions before this one, so we can simply emit 245 // the then branch 246 return flat_exp(env, ctx, ite->thenExpr(i), r, b); 247 } 248 // had var conditions, so we have to take them into account 249 // and emit new conditional clause 250 // add another condition and definedness variable 251 conditions.emplace_back(constants().literalTrue); 252 for (unsigned int j = 0; j < results.size(); j++) { 253 EE ethen = flat_exp(env, cmix, e_then[j][i](), nullptr, nullptr); 254 assert(ethen.b()); 255 defined[j].push_back(ethen.b); 256 allDefined = allDefined && (ethen.b() == constants().literalTrue); 257 branches[j].push_back(ethen.r); 258 if (ethen.r()->type().isvar()) { 259 allBranchesPar[j] = false; 260 } 261 } 262 foundTrueBranch = true; 263 } else { 264 GCLock lock; 265 conditions.emplace_back(constants().literalFalse); 266 for (unsigned int j = 0; j < results.size(); j++) { 267 defined[j].push_back(constants().literalTrue); 268 branches[j].push_back(create_dummy_value(env, e_then[j][i]()->type())); 269 } 270 } 271 } else { 272 allConditionsPar = false; 273 // add current condition and definedness variable 274 conditions.push_back(e_if.r); 275 276 for (unsigned int j = 0; j < results.size(); j++) { 277 // flatten the then branch 278 EE ethen = flat_exp(env, cmix, e_then[j][i](), nullptr, nullptr); 279 280 assert(ethen.b()); 281 defined[j].push_back(ethen.b); 282 allDefined = allDefined && (ethen.b() == constants().literalTrue); 283 branches[j].push_back(ethen.r); 284 if (ethen.r()->type().isvar()) { 285 allBranchesPar[j] = false; 286 } 287 } 288 } 289 // update bounds 290 291 if (cond) { 292 for (unsigned int j = 0; j < results.size(); j++) { 293 if (r_bounds_valid_int[j] && e_then[j][i]()->type().isint()) { 294 GCLock lock; 295 IntBounds ib_then = compute_int_bounds(env, branches[j][i]()); 296 if (ib_then.valid) { 297 r_bounds_int[j].push_back(ib_then); 298 } 299 r_bounds_valid_int[j] = r_bounds_valid_int[j] && ib_then.valid; 300 } else if (r_bounds_valid_set[j] && e_then[j][i]()->type().isIntSet()) { 301 GCLock lock; 302 IntSetVal* isv = compute_intset_bounds(env, branches[j][i]()); 303 if (isv != nullptr) { 304 r_bounds_set[j].push_back(isv); 305 } 306 r_bounds_valid_set[j] = r_bounds_valid_set[j] && (isv != nullptr); 307 } else if (r_bounds_valid_float[j] && e_then[j][i]()->type().isfloat()) { 308 GCLock lock; 309 FloatBounds fb_then = compute_float_bounds(env, branches[j][i]()); 310 if (fb_then.valid) { 311 r_bounds_float[j].push_back(fb_then); 312 } 313 r_bounds_valid_float[j] = r_bounds_valid_float[j] && fb_then.valid; 314 } 315 } 316 } 317 } 318 319 if (allConditionsPar) { 320 // no var condition, and all par conditions were false, 321 // so simply emit else branch 322 return flat_exp(env, ctx, ite->elseExpr(), r, b); 323 } 324 325 for (auto& result : results) { 326 if (result == nullptr) { 327 // need to introduce new result variable 328 GCLock lock; 329 auto* ti = new TypeInst(Location().introduce(), ite->type(), nullptr); 330 result = new_vardecl(env, Ctx(), ti, nullptr, nullptr, nullptr); 331 } 332 } 333 334 if (conditions.back()() != constants().literalTrue) { 335 // The last condition wasn't fixed to true, we need to look at the else branch 336 conditions.emplace_back(constants().literalTrue); 337 338 for (unsigned int j = 0; j < results.size(); j++) { 339 // flatten else branch 340 EE eelse = flat_exp(env, cmix, e_else[j](), nullptr, nullptr); 341 assert(eelse.b()); 342 defined[j].push_back(eelse.b); 343 allDefined = allDefined && (eelse.b() == constants().literalTrue); 344 branches[j].push_back(eelse.r); 345 if (eelse.r()->type().isvar()) { 346 allBranchesPar[j] = false; 347 } 348 349 if (r_bounds_valid_int[j] && e_else[j]()->type().isint()) { 350 GCLock lock; 351 IntBounds ib_else = compute_int_bounds(env, eelse.r()); 352 if (ib_else.valid) { 353 r_bounds_int[j].push_back(ib_else); 354 } 355 r_bounds_valid_int[j] = r_bounds_valid_int[j] && ib_else.valid; 356 } else if (r_bounds_valid_set[j] && e_else[j]()->type().isIntSet()) { 357 GCLock lock; 358 IntSetVal* isv = compute_intset_bounds(env, eelse.r()); 359 if (isv != nullptr) { 360 r_bounds_set[j].push_back(isv); 361 } 362 r_bounds_valid_set[j] = r_bounds_valid_set[j] && (isv != nullptr); 363 } else if (r_bounds_valid_float[j] && e_else[j]()->type().isfloat()) { 364 GCLock lock; 365 FloatBounds fb_else = compute_float_bounds(env, eelse.r()); 366 if (fb_else.valid) { 367 r_bounds_float[j].push_back(fb_else); 368 } 369 r_bounds_valid_float[j] = r_bounds_valid_float[j] && fb_else.valid; 370 } 371 } 372 } 373 374 // update domain of result variable with bounds from all branches 375 376 for (unsigned int j = 0; j < results.size(); j++) { 377 VarDecl* nr = results[j]; 378 GCLock lock; 379 if (r_bounds_valid_int[j] && ite->type().isint()) { 380 IntVal lb = IntVal::infinity(); 381 IntVal ub = -IntVal::infinity(); 382 for (auto& i : r_bounds_int[j]) { 383 lb = std::min(lb, i.l); 384 ub = std::max(ub, i.u); 385 } 386 if (nr->ti()->domain() != nullptr) { 387 IntSetVal* isv = eval_intset(env, nr->ti()->domain()); 388 Ranges::Const<IntVal> ite_r(lb, ub); 389 IntSetRanges isv_r(isv); 390 Ranges::Inter<IntVal, Ranges::Const<IntVal>, IntSetRanges> inter(ite_r, isv_r); 391 IntSetVal* isv_new = IntSetVal::ai(inter); 392 if (isv_new->card() != isv->card()) { 393 auto* r_dom = new SetLit(Location().introduce(), isv_new); 394 nr->ti()->domain(r_dom); 395 } 396 } else { 397 auto* r_dom = new SetLit(Location().introduce(), IntSetVal::a(lb, ub)); 398 nr->ti()->domain(r_dom); 399 nr->ti()->setComputedDomain(true); 400 } 401 } else if (r_bounds_valid_set[j] && ite->type().isIntSet()) { 402 IntSetVal* isv_branches = IntSetVal::a(); 403 for (auto& i : r_bounds_set[j]) { 404 IntSetRanges i0(isv_branches); 405 IntSetRanges i1(i); 406 Ranges::Union<IntVal, IntSetRanges, IntSetRanges> u(i0, i1); 407 isv_branches = IntSetVal::ai(u); 408 } 409 if (nr->ti()->domain() != nullptr) { 410 IntSetVal* isv = eval_intset(env, nr->ti()->domain()); 411 IntSetRanges isv_r(isv); 412 IntSetRanges isv_branches_r(isv_branches); 413 Ranges::Inter<IntVal, IntSetRanges, IntSetRanges> inter(isv_branches_r, isv_r); 414 IntSetVal* isv_new = IntSetVal::ai(inter); 415 if (isv_new->card() != isv->card()) { 416 auto* r_dom = new SetLit(Location().introduce(), isv_new); 417 nr->ti()->domain(r_dom); 418 } 419 } else { 420 auto* r_dom = new SetLit(Location().introduce(), isv_branches); 421 nr->ti()->domain(r_dom); 422 nr->ti()->setComputedDomain(true); 423 } 424 } else if (r_bounds_valid_float[j] && ite->type().isfloat()) { 425 FloatVal lb = FloatVal::infinity(); 426 FloatVal ub = -FloatVal::infinity(); 427 for (auto& i : r_bounds_float[j]) { 428 lb = std::min(lb, i.l); 429 ub = std::max(ub, i.u); 430 } 431 if (nr->ti()->domain() != nullptr) { 432 FloatSetVal* isv = eval_floatset(env, nr->ti()->domain()); 433 Ranges::Const<FloatVal> ite_r(lb, ub); 434 FloatSetRanges isv_r(isv); 435 Ranges::Inter<FloatVal, Ranges::Const<FloatVal>, FloatSetRanges> inter(ite_r, isv_r); 436 FloatSetVal* fsv_new = FloatSetVal::ai(inter); 437 auto* r_dom = new SetLit(Location().introduce(), fsv_new); 438 nr->ti()->domain(r_dom); 439 } else { 440 auto* r_dom = new SetLit(Location().introduce(), FloatSetVal::a(lb, ub)); 441 nr->ti()->domain(r_dom); 442 nr->ti()->setComputedDomain(true); 443 } 444 } 445 } 446 447 // Create ite predicate calls 448 GCLock lock; 449 auto* al_cond = new ArrayLit(Location().introduce(), conditions); 450 al_cond->type(Type::varbool(1)); 451 for (unsigned int j = 0; j < results.size(); j++) { 452 auto* al_branches = new ArrayLit(Location().introduce(), branches[j]); 453 Type branches_t = results[j]->type(); 454 branches_t.dim(1); 455 branches_t.ti(allBranchesPar[j] ? Type::TI_PAR : Type::TI_VAR); 456 al_branches->type(branches_t); 457 Call* ite_pred = new Call(ite->loc().introduce(), ASTString("if_then_else"), 458 {al_cond, al_branches, results[j]->id()}); 459 ite_pred->decl(env.model->matchFn(env, ite_pred, false)); 460 ite_pred->type(Type::varbool()); 461 (void)flat_exp(env, Ctx(), ite_pred, constants().varTrue, constants().varTrue); 462 } 463 EE ret; 464 if (noOtherBranches) { 465 ret.r = constants().varTrue->id(); 466 } else { 467 ret.r = results.back()->id(); 468 } 469 if (allDefined) { 470 bind(env, Ctx(), b, constants().literalTrue); 471 ret.b = constants().literalTrue; 472 } else { 473 // Otherwise, constraint linking conditions, b and the definedness variables 474 if (b == nullptr) { 475 CallStackItem _csi(env, new StringLit(Location().introduce(), "b")); 476 b = new_vardecl(env, Ctx(), new TypeInst(Location().introduce(), Type::varbool()), nullptr, 477 nullptr, nullptr); 478 } 479 ret.b = b->id(); 480 481 std::vector<Expression*> defined_conjunctions(ite->size() + 1); 482 for (unsigned int i = 0; i < ite->size() + 1; i++) { 483 std::vector<Expression*> def_i; 484 for (auto& j : defined) { 485 assert(j.size() > i); 486 if (j[i]() != constants().literalTrue) { 487 def_i.push_back(j[i]()); 488 } 489 } 490 if (def_i.empty()) { 491 defined_conjunctions[i] = constants().literalTrue; 492 } else if (def_i.size() == 1) { 493 defined_conjunctions[i] = def_i[0]; 494 } else { 495 auto* al = new ArrayLit(Location().introduce(), def_i); 496 al->type(Type::varbool(1)); 497 Call* forall = new Call(Location().introduce(), constants().ids.forall, {al}); 498 forall->decl(env.model->matchFn(env, forall, false)); 499 forall->type(forall->decl()->rtype(env, {al}, false)); 500 defined_conjunctions[i] = forall; 501 } 502 } 503 auto* al_defined = new ArrayLit(Location().introduce(), defined_conjunctions); 504 al_defined->type(Type::varbool(1)); 505 Call* ite_defined_pred = new Call(ite->loc().introduce(), ASTString("if_then_else_partiality"), 506 {al_cond, al_defined, b->id()}); 507 ite_defined_pred->decl(env.model->matchFn(env, ite_defined_pred, false)); 508 ite_defined_pred->type(Type::varbool()); 509 (void)flat_exp(env, Ctx(), ite_defined_pred, constants().varTrue, constants().varTrue); 510 } 511 512 return ret; 513} 514 515} // namespace MiniZinc