this repo has no description
1/* -*- mode: C++; c-basic-offset: 2; indent-tabs-mode: nil -*- */
2
3/*
4 * Main authors:
5 * Guido Tack <guido.tack@monash.edu>
6 */
7
8/* This Source Code Form is subject to the terms of the Mozilla Public
9 * License, v. 2.0. If a copy of the MPL was not distributed with this
10 * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
11
12#include <minizinc/flat_exp.hh>
13
14namespace MiniZinc {
15
16std::vector<Expression*> get_conjuncts(Expression* e) {
17 std::vector<Expression*> conj_stack;
18 std::vector<Expression*> conjuncts;
19 conj_stack.push_back(e);
20 while (!conj_stack.empty()) {
21 Expression* e = conj_stack.back();
22 conj_stack.pop_back();
23 if (BinOp* bo = e->dyn_cast<BinOp>()) {
24 if (bo->op() == BOT_AND) {
25 conj_stack.push_back(bo->rhs());
26 conj_stack.push_back(bo->lhs());
27 } else {
28 conjuncts.push_back(e);
29 }
30 } else {
31 conjuncts.push_back(e);
32 }
33 }
34 return conjuncts;
35}
36
37void classify_conjunct(Expression* e, IdMap<int>& eq_occurrences,
38 IdMap<std::pair<Expression*, Expression*>>& eq_branches,
39 std::vector<Expression*>& other_branches) {
40 if (BinOp* bo = e->dyn_cast<BinOp>()) {
41 if (bo->op() == BOT_EQ) {
42 if (Id* ident = bo->lhs()->dyn_cast<Id>()) {
43 if (eq_branches.find(ident) == eq_branches.end()) {
44 IdMap<int>::iterator it = eq_occurrences.find(ident);
45 if (it == eq_occurrences.end()) {
46 eq_occurrences.insert(ident, 1);
47 } else {
48 eq_occurrences.get(ident)++;
49 }
50 eq_branches.insert(ident, std::make_pair(bo->rhs(), bo));
51 return;
52 }
53 } else if (Id* ident = bo->rhs()->dyn_cast<Id>()) {
54 if (eq_branches.find(ident) == eq_branches.end()) {
55 IdMap<int>::iterator it = eq_occurrences.find(ident);
56 if (it == eq_occurrences.end()) {
57 eq_occurrences.insert(ident, 1);
58 } else {
59 eq_occurrences.get(ident)++;
60 }
61 eq_branches.insert(ident, std::make_pair(bo->lhs(), bo));
62 return;
63 }
64 }
65 }
66 }
67 other_branches.push_back(e);
68}
69
70EE flatten_ite(EnvI& env, Ctx ctx, Expression* e, VarDecl* r, VarDecl* b) {
71 CallStackItem _csi(env, e);
72 ITE* ite = e->cast<ITE>();
73
74 // The conditions of each branch of the if-then-else
75 std::vector<KeepAlive> conditions;
76 // Whether the right hand side of each branch is defined
77 std::vector<std::vector<KeepAlive>> defined;
78 // The right hand side of each branch
79 std::vector<std::vector<KeepAlive>> branches;
80 // Whether all branches are fixed
81 std::vector<bool> allBranchesPar;
82
83 // Compute bounds of result as union bounds of all branches
84 std::vector<std::vector<IntBounds>> r_bounds_int;
85 std::vector<bool> r_bounds_valid_int;
86 std::vector<std::vector<IntSetVal*>> r_bounds_set;
87 std::vector<bool> r_bounds_valid_set;
88 std::vector<std::vector<FloatBounds>> r_bounds_float;
89 std::vector<bool> r_bounds_valid_float;
90
91 bool allConditionsPar = true;
92 bool allDefined = true;
93
94 // The result variables of each generated conditional
95 std::vector<VarDecl*> results;
96 // The then-expressions of each generated conditional
97 std::vector<std::vector<Expression*>> e_then;
98 // The else-expressions of each generated conditional
99 std::vector<Expression*> e_else;
100
101 bool noOtherBranches = true;
102 if (ite->type() == Type::varbool() && ctx.b == C_ROOT && r == constants().var_true) {
103 // Check if all branches are of the form x1=e1 /\ ... /\ xn=en
104 IdMap<int> eq_occurrences;
105 std::vector<IdMap<std::pair<Expression*, Expression*>>> eq_branches(ite->size() + 1);
106 std::vector<std::vector<Expression*>> other_branches(ite->size() + 1);
107 for (int i = 0; i < ite->size(); i++) {
108 auto conjuncts = get_conjuncts(ite->e_then(i));
109 for (auto c : conjuncts) {
110 classify_conjunct(c, eq_occurrences, eq_branches[i], other_branches[i]);
111 }
112 noOtherBranches = noOtherBranches && other_branches[i].empty();
113 }
114 {
115 auto conjuncts = get_conjuncts(ite->e_else());
116 for (auto c : conjuncts) {
117 classify_conjunct(c, eq_occurrences, eq_branches[ite->size()], other_branches[ite->size()]);
118 }
119 noOtherBranches = noOtherBranches && other_branches[ite->size()].empty();
120 }
121 for (auto& e : eq_occurrences) {
122 if (e.second >= ite->size()) {
123 // Any identifier that occurs in all or all but one branch gets its own conditional
124 results.push_back(e.first->decl());
125 e_then.push_back(std::vector<Expression*>());
126 for (int i = 0; i < ite->size(); i++) {
127 IdMap<std::pair<Expression*, Expression*>>::iterator it = eq_branches[i].find(e.first);
128 if (it == eq_branches[i].end()) {
129 // not found, simply push x=x
130 e_then.back().push_back(e.first);
131 } else {
132 e_then.back().push_back(it->second.first);
133 }
134 }
135 {
136 IdMap<std::pair<Expression*, Expression*>>::iterator it =
137 eq_branches[ite->size()].find(e.first);
138 if (it == eq_branches[ite->size()].end()) {
139 // not found, simply push x=x
140 e_else.push_back(e.first);
141 } else {
142 e_else.push_back(it->second.first);
143 }
144 }
145 } else {
146 // All other identifiers are put in the vector of "other" branches
147 for (int i = 0; i <= ite->size(); i++) {
148 IdMap<std::pair<Expression*, Expression*>>::iterator it = eq_branches[i].find(e.first);
149 if (it != eq_branches[i].end()) {
150 other_branches[i].push_back(it->second.second);
151 noOtherBranches = false;
152 eq_branches[i].remove(e.first);
153 }
154 }
155 }
156 }
157 if (!noOtherBranches) {
158 results.push_back(r);
159 e_then.push_back(std::vector<Expression*>());
160 for (int i = 0; i < ite->size(); i++) {
161 if (eq_branches[i].size() == 0) {
162 e_then.back().push_back(ite->e_then(i));
163 } else if (other_branches[i].size() == 0) {
164 e_then.back().push_back(constants().lit_true);
165 } else if (other_branches[i].size() == 1) {
166 e_then.back().push_back(other_branches[i][0]);
167 } else {
168 ArrayLit* al = new ArrayLit(Location().introduce(), other_branches[i]);
169 al->type(Type::varbool(1));
170 Call* forall = new Call(Location().introduce(), constants().ids.forall, {al});
171 forall->decl(env.model->matchFn(env, forall, false));
172 forall->type(forall->decl()->rtype(env, {al}, false));
173 e_then.back().push_back(forall);
174 }
175 }
176 {
177 if (eq_branches[ite->size()].size() == 0) {
178 e_else.push_back(ite->e_else());
179 } else if (other_branches[ite->size()].size() == 0) {
180 e_else.push_back(constants().lit_true);
181 } else if (other_branches[ite->size()].size() == 1) {
182 e_else.push_back(other_branches[ite->size()][0]);
183 } else {
184 ArrayLit* al = new ArrayLit(Location().introduce(), other_branches[ite->size()]);
185 al->type(Type::varbool(1));
186 Call* forall = new Call(Location().introduce(), constants().ids.forall, {al});
187 forall->decl(env.model->matchFn(env, forall, false));
188 forall->type(forall->decl()->rtype(env, {al}, false));
189 e_else.push_back(forall);
190 }
191 }
192 }
193 } else {
194 noOtherBranches = false;
195 results.push_back(r);
196 e_then.push_back(std::vector<Expression*>());
197 for (int i = 0; i < ite->size(); i++) {
198 e_then.back().push_back(ite->e_then(i));
199 }
200 e_else.push_back(ite->e_else());
201 }
202 allBranchesPar.resize(results.size());
203 r_bounds_valid_int.resize(results.size());
204 r_bounds_int.resize(results.size());
205 r_bounds_valid_float.resize(results.size());
206 r_bounds_float.resize(results.size());
207 r_bounds_valid_set.resize(results.size());
208 r_bounds_set.resize(results.size());
209 defined.resize(results.size());
210 branches.resize(results.size());
211 for (unsigned int i = 0; i < results.size(); i++) {
212 allBranchesPar[i] = true;
213 r_bounds_valid_int[i] = true;
214 r_bounds_valid_float[i] = true;
215 r_bounds_valid_set[i] = true;
216 }
217
218 Ctx cmix;
219 cmix.b = C_MIX;
220 cmix.i = C_MIX;
221
222 for (int i = 0; i < ite->size(); i++) {
223 bool cond = true;
224 EE e_if;
225 if (ite->e_if(i)->isa<Call>() && ite->e_if(i)->cast<Call>()->id() == "mzn_in_root_context") {
226 e_if = EE(constants().boollit(ctx.b == C_ROOT), constants().lit_true);
227 } else {
228 e_if = flat_exp(env, cmix, ite->e_if(i), NULL, constants().var_true);
229 }
230 if (e_if.r()->type() == Type::parbool()) {
231 {
232 GCLock lock;
233 cond = eval_bool(env, e_if.r());
234 }
235 if (cond) {
236 if (allConditionsPar || conditions.size() == 0) {
237 // no var conditions before this one, so we can simply emit
238 // the then branch
239 return flat_exp(env, ctx, ite->e_then(i), r, b);
240 }
241 // had var conditions, so we have to take them into account
242 // and emit new conditional clause
243 Ctx cmix;
244 cmix.b = C_MIX;
245 cmix.i = C_MIX;
246 // add another condition and definedness variable
247 conditions.push_back(constants().lit_true);
248 for (unsigned int j = 0; j < results.size(); j++) {
249 EE ethen = flat_exp(env, cmix, e_then[j][i], NULL, NULL);
250 assert(ethen.b());
251 defined[j].push_back(ethen.b);
252 allDefined = allDefined && (ethen.b() == constants().lit_true);
253 branches[j].push_back(ethen.r);
254 if (ethen.r()->type().isvar()) {
255 allBranchesPar[j] = false;
256 }
257 }
258 break;
259 }
260 } else {
261 allConditionsPar = false;
262 // add current condition and definedness variable
263 conditions.push_back(e_if.r);
264
265 for (unsigned int j = 0; j < results.size(); j++) {
266 // flatten the then branch
267 EE ethen = flat_exp(env, cmix, e_then[j][i], NULL, NULL);
268
269 assert(ethen.b());
270 defined[j].push_back(ethen.b);
271 allDefined = allDefined && (ethen.b() == constants().lit_true);
272 branches[j].push_back(ethen.r);
273 if (ethen.r()->type().isvar()) {
274 allBranchesPar[j] = false;
275 }
276 }
277 }
278 // update bounds
279
280 if (cond) {
281 for (unsigned int j = 0; j < results.size(); j++) {
282 if (r_bounds_valid_int[j] && e_then[j][i]->type().isint()) {
283 GCLock lock;
284 IntBounds ib_then = compute_int_bounds(env, e_then[j][i]);
285 if (ib_then.valid) r_bounds_int[j].push_back(ib_then);
286 r_bounds_valid_int[j] = r_bounds_valid_int[j] && ib_then.valid;
287 } else if (r_bounds_valid_set[j] && e_then[j][i]->type().isintset()) {
288 GCLock lock;
289 IntSetVal* isv = compute_intset_bounds(env, e_then[j][i]);
290 if (isv) r_bounds_set[j].push_back(isv);
291 r_bounds_valid_set[j] = r_bounds_valid_set[j] && isv;
292 } else if (r_bounds_valid_float[j] && e_then[j][i]->type().isfloat()) {
293 GCLock lock;
294 FloatBounds fb_then = compute_float_bounds(env, e_then[j][i]);
295 if (fb_then.valid) r_bounds_float[j].push_back(fb_then);
296 r_bounds_valid_float[j] = r_bounds_valid_float[j] && fb_then.valid;
297 }
298 }
299 }
300 }
301
302 if (allConditionsPar) {
303 // no var condition, and all par conditions were false,
304 // so simply emit else branch
305 return flat_exp(env, ctx, ite->e_else(), r, b);
306 }
307
308 for (unsigned int j = 0; j < results.size(); j++) {
309 if (results[j] == NULL) {
310 // need to introduce new result variable
311 GCLock lock;
312 TypeInst* ti = new TypeInst(Location().introduce(), ite->type(), NULL);
313 results[j] = newVarDecl(env, Ctx(), ti, NULL, NULL, NULL);
314 }
315 }
316
317 if (conditions.back()() != constants().lit_true) {
318 // The last condition wasn't fixed to true, we need to look at the else branch
319 conditions.push_back(constants().lit_true);
320
321 for (unsigned int j = 0; j < results.size(); j++) {
322 VarDecl* nr = results[j];
323
324 // update bounds of result with bounds of else branch
325
326 if (r_bounds_valid_int[j] && e_else[j]->type().isint()) {
327 GCLock lock;
328 IntBounds ib_else = compute_int_bounds(env, e_else[j]);
329 if (ib_else.valid) {
330 r_bounds_int[j].push_back(ib_else);
331 IntVal lb = IntVal::infinity();
332 IntVal ub = -IntVal::infinity();
333 for (unsigned int i = 0; i < r_bounds_int[j].size(); i++) {
334 lb = std::min(lb, r_bounds_int[j][i].l);
335 ub = std::max(ub, r_bounds_int[j][i].u);
336 }
337 if (results[j]) {
338 IntBounds orig_r_bounds = compute_int_bounds(env, results[j]->id());
339 if (orig_r_bounds.valid) {
340 lb = std::max(lb, orig_r_bounds.l);
341 ub = std::min(ub, orig_r_bounds.u);
342 }
343 }
344 SetLit* r_dom = new SetLit(Location().introduce(), IntSetVal::a(lb, ub));
345 nr->ti()->domain(r_dom);
346 }
347 } else if (r_bounds_valid_set[j] && e_else[j]->type().isintset()) {
348 GCLock lock;
349 IntSetVal* isv_else = compute_intset_bounds(env, e_else[j]);
350 if (isv_else) {
351 IntSetVal* isv = isv_else;
352 for (unsigned int i = 0; i < r_bounds_set[j].size(); i++) {
353 IntSetRanges i0(isv);
354 IntSetRanges i1(r_bounds_set[j][i]);
355 Ranges::Union<IntVal, IntSetRanges, IntSetRanges> u(i0, i1);
356 isv = IntSetVal::ai(u);
357 }
358 if (results[j]) {
359 IntSetVal* orig_r_bounds = compute_intset_bounds(env, results[j]->id());
360 if (orig_r_bounds) {
361 IntSetRanges i0(isv);
362 IntSetRanges i1(orig_r_bounds);
363 Ranges::Inter<IntVal, IntSetRanges, IntSetRanges> inter(i0, i1);
364 isv = IntSetVal::ai(inter);
365 }
366 }
367 SetLit* r_dom = new SetLit(Location().introduce(), isv);
368 nr->ti()->domain(r_dom);
369 }
370 } else if (r_bounds_valid_float[j] && e_else[j]->type().isfloat()) {
371 GCLock lock;
372 FloatBounds fb_else = compute_float_bounds(env, e_else[j]);
373 if (fb_else.valid) {
374 FloatVal lb = fb_else.l;
375 FloatVal ub = fb_else.u;
376 for (unsigned int i = 0; i < r_bounds_float[j].size(); i++) {
377 lb = std::min(lb, r_bounds_float[j][i].l);
378 ub = std::max(ub, r_bounds_float[j][i].u);
379 }
380 if (results[j]) {
381 FloatBounds orig_r_bounds = compute_float_bounds(env, results[j]->id());
382 if (orig_r_bounds.valid) {
383 lb = std::max(lb, orig_r_bounds.l);
384 ub = std::min(ub, orig_r_bounds.u);
385 }
386 }
387 BinOp* r_dom =
388 new BinOp(Location().introduce(), FloatLit::a(lb), BOT_DOTDOT, FloatLit::a(ub));
389 r_dom->type(Type::parfloat(1));
390 nr->ti()->domain(r_dom);
391 }
392 }
393
394 // flatten else branch
395 EE eelse = flat_exp(env, cmix, e_else[j], NULL, NULL);
396 assert(eelse.b());
397 defined[j].push_back(eelse.b);
398 allDefined = allDefined && (eelse.b() == constants().lit_true);
399 branches[j].push_back(eelse.r);
400 if (eelse.r()->type().isvar()) {
401 allBranchesPar[j] = false;
402 }
403 }
404 }
405
406 // Create ite predicate calls
407 GCLock lock;
408 ArrayLit* al_cond = new ArrayLit(Location().introduce(), conditions);
409 al_cond->type(Type::varbool(1));
410 for (unsigned int j = 0; j < results.size(); j++) {
411 ArrayLit* al_branches = new ArrayLit(Location().introduce(), branches[j]);
412 Type branches_t = results[j]->type();
413 branches_t.dim(1);
414 branches_t.ti(allBranchesPar[j] ? Type::TI_PAR : Type::TI_VAR);
415 al_branches->type(branches_t);
416 Call* ite_pred = new Call(ite->loc().introduce(), ASTString("if_then_else"),
417 {al_cond, al_branches, results[j]->id()});
418 ite_pred->decl(env.model->matchFn(env, ite_pred, false));
419 ite_pred->type(Type::varbool());
420 (void)flat_exp(env, Ctx(), ite_pred, constants().var_true, constants().var_true);
421 }
422 EE ret;
423 if (noOtherBranches) {
424 ret.r = constants().var_true->id();
425 } else {
426 ret.r = results.back()->id();
427 }
428 if (allDefined) {
429 bind(env, ctx, b, constants().lit_true);
430 ret.b = constants().lit_true;
431 } else {
432 // Otherwise, constraint linking conditions, b and the definedness variables
433 if (b == NULL) {
434 CallStackItem _csi(env, new StringLit(Location().introduce(), "b"));
435 b = newVarDecl(env, Ctx(), new TypeInst(Location().introduce(), Type::varbool()), NULL, NULL,
436 NULL);
437 }
438 ret.b = b->id();
439
440 std::vector<Expression*> defined_conjunctions(ite->size() + 1);
441 for (unsigned int i = 0; i < ite->size(); i++) {
442 std::vector<Expression*> def_i;
443 for (unsigned int j = 0; j < defined.size(); j++) {
444 if (defined[j][i]() != constants().lit_true) {
445 def_i.push_back(defined[j][i]());
446 }
447 }
448 if (def_i.size() == 0) {
449 defined_conjunctions[i] = constants().lit_true;
450 } else if (def_i.size() == 1) {
451 defined_conjunctions[i] = def_i[0];
452 } else {
453 ArrayLit* al = new ArrayLit(Location().introduce(), def_i);
454 al->type(Type::varbool(1));
455 Call* forall = new Call(Location().introduce(), constants().ids.forall, {al});
456 forall->decl(env.model->matchFn(env, forall, false));
457 forall->type(forall->decl()->rtype(env, {al}, false));
458 defined_conjunctions[i] = forall;
459 }
460 }
461 ArrayLit* al_defined = new ArrayLit(Location().introduce(), defined_conjunctions);
462 al_defined->type(Type::varbool(1));
463 Call* ite_defined_pred = new Call(ite->loc().introduce(), ASTString("if_then_else_partiality"),
464 {al_cond, al_defined, b->id()});
465 ite_defined_pred->decl(env.model->matchFn(env, ite_defined_pred, false));
466 ite_defined_pred->type(Type::varbool());
467 (void)flat_exp(env, Ctx(), ite_defined_pred, constants().var_true, constants().var_true);
468 }
469
470 return ret;
471}
472
473} // namespace MiniZinc