"""新規関数と新規制約に対応するためのフック関数を提供.
"""
# [docs]
def defaultFunctionHook(function, encoder):
"""新規関数に対応するためのフックとしてデフォールトで設定されている関数.
以下に対応している.
* ``["div", X, n]``
* ``["mod", X, n]``
* ``["abs", X]``
* ``["min", X1, ..., Xn]``
* ``["max", X1, ..., Xn]``
* ``["if", A, X, Y]``
Args:
function (list): 制約充足問題中の式.
encoder (Encoder): Encoderインスタンス.
Returns:
引数の function そのまま,または function を変換した式.
"""
from .csp import LogEncoder
def getBound(xx):
bounds = [ encoder.getBound(x) for x in xx ]
lb = min(b[0] for b in bounds)
ub = max(b[1] for b in bounds)
return (lb, ub)
def decomposeAbs(x):
(lb, ub) = encoder.getBound(x)
z = Var()
encoder.put(["int", z, 0, max(abs(lb), abs(ub))])
encoder.put(["imp", ["<", x, 0], ["==", z, ["-", x]]])
encoder.put(["imp", [">=", x, 0], ["==", z, x]])
return z
def decomposeMin(xx):
(lb, ub) = getBound(xx)
z = Var()
c1 = [ ["<=", z, x] for x in xx ]
c2 = [ [">=", z, x] for x in xx ]
encoder.put(["int", z, lb, ub])
encoder.put(["and", *c1, ["or", *c2]])
return z
def decomposeMax(xx):
(lb, ub) = getBound(xx)
z = Var()
c1 = [ [">=", z, x] for x in xx ]
c2 = [ ["<=", z, x] for x in xx ]
encoder.put(["int", z, lb, ub])
encoder.put(["and", *c1, ["or", *c2]])
return z
def decomposeIf(c, x, y):
(lb, ub) = getBound([x, y])
(p, z) = (Bool(), Var())
encoder.put(["equ", p, c])
encoder.put(["int", z, lb, ub])
encoder.put(["imp", p, ["==", z, x]])
encoder.put(["imp", ~p, ["==", z, y]])
return z
def decomposeDivMod(x, n):
if not isinstance(n, int) or n <= 0:
raise CspsatException(f"divあるいはmodの第2引数が正の整数定数でない: {n}")
(lb, ub) = encoder.getBound(x)
(q, r) = (Var(), Var())
encoder.put(["int", q, lb//n, ub//n])
encoder.put(["int", r, 0, n-1])
encoder.put(["==", x, ["+", ["*", n, q], r]])
return (q, r)
def decomposeBit(x, k):
if not isinstance(encoder, LogEncoder):
raise CspsatException("EncoderがLogEncoderでない")
if encoder.intLb(x) != 0:
raise CspsatException(f"変数の下限が0でない: {x}")
return encoder.varBitK(x, k)
match function:
case ["div", x, n] | ["//", x, n]:
function = decomposeDivMod(x, n)[0]
case ["mod", x, n] | ["%", x, n]:
function = decomposeDivMod(x, n)[1]
case ["abs", x]:
function = decomposeAbs(x)
case ["min", *xx]:
function = decomposeMin(xx)
case ["max", *xx]:
function = decomposeMax(xx)
case ["if", c, x, y]:
function = decomposeIf(c, x, y)
case ["bit", x, k]:
function = decomposeBit(x, k)
return function
# [docs]
def defaultConstraintHook(constraint, encoder):
"""新規制約に対応するためのフックとしてデフォールトで設定されている関数.
以下に対応している.
* ``["alldifferent", X1, ..., Xn]``
* ``["lexCmp", cmp, [X1,...Xn], [Y1,...,Yn]]`` (cmpは"==", "!=", "<=", "<", ">=", ">")
* ``["mulCmp", cmp, X, Y, Z]`` (cmpは"==", "!=", "<=", "<", ">=", ">")
* ``["powCmp", cmp, X, n, Z]`` (cmpは"==", "!=", "<=", "<", ">=", ">")
* ``["bits", [X1,...Xn], X]``
* ``["bit", X, i]``
Args:
constraint (list): 制約充足問題中の制約.
encoder (Encoder): Encoderインスタンス.
Returns:
引数の constraint そのまま,または constraint を変換した制約.
"""
from .csp import LogEncoder
def decomposeAlldifferent(xx):
bounds = [ encoder.getBound(x) for x in xx ]
lb = min(b[0] for b in bounds)
ub = max(b[1] for b in bounds)
d = ub - lb + 1
m = d - len(xx)
if m < 0:
yield FALSE
return
if len(xx) <= 2:
for (x1,x2) in itertools.combinations(xx, 2):
yield ["!=", x1, x2]
return
t = Bool()
if m > 0:
tt = [ t(j) for j in range(lb, ub+1) ]
yield ["eqK", tt, m]
for k in range(lb, ub+1):
p = Bool()
for (i,x) in enumerate(xx):
yield ["equ", p(i), ["==", x, k]]
pp = [ p(i) for i in range(len(xx)) ]
if m > 0:
pp.append(t(k))
yield ["eqK", pp, 1]
def _fillZero(xx, yy):
if len(xx) < len(yy):
xx = xx + [0] * (len(yy)-len(xx))
elif len(xx) > len(yy):
yy = yy + [0] * (len(xx)-len(yy))
return (xx, yy)
def decomposeLexEq(xx, yy):
(xx, yy) = _fillZero(xx, yy)
for (i,x) in enumerate(xx):
yield ["==", x, yy[i]]
def decomposeLexLe(xx, yy, less=False):
(xx, yy) = _fillZero(xx, yy)
n = len(xx)
a = Bool()
yield a(-1)
for i in range(n):
yield ["or", ~a(i-1), ["<=", ["+", xx[i], ~a(i)], yy[i]]]
yield ~a(n-1) if less else a(n-1)
def decomposeMulCmp(cmp, x, y, z):
if not isinstance(encoder, LogEncoder):
raise CspsatException("EncoderがLogEncoderでない")
binEqu = BinaryEquation(encoder)
binEqu.addMul(x, y)
binEqu.add(z, a=-1)
yield from [ ["or", *c] for c in binEqu.cmp0(cmp) ]
def decomposePowCmp(cmp, x, n, z):
if not isinstance(encoder, LogEncoder):
raise CspsatException("EncoderがLogEncoderでない")
binEqu = BinaryEquation(encoder)
binEqu.addPower(x, n)
binEqu.add(z, a=-1)
yield from [ ["or", *c] for c in binEqu.cmp0(cmp) ]
def decomposeBits(xx, x):
if not isinstance(encoder, LogEncoder):
raise CspsatException("EncoderがLogEncoderでない")
if encoder.intLb(x) != 0:
raise CspsatException(f"変数の下限が0でない: {x}")
yield from [ ["or", *c] for c in Binary.eq(xx, encoder.getBools(x)) ]
def decomposeBit(x, k):
if not isinstance(encoder, LogEncoder):
raise CspsatException("EncoderがLogEncoderでない")
if encoder.intLb(x) != 0:
raise CspsatException(f"変数の下限が0でない: {x}")
yield encoder.varBitK(x, k)
match constraint:
case ["alldifferent", *args]:
cs = decomposeAlldifferent(args)
constraint = ["and", *cs]
case ["lexCmp", "==", xx, yy]:
cs = decomposeLexEq(xx, yy)
constraint = ["and", *cs]
case ["lexCmp", "!=", xx, yy]:
cs = decomposeLexEq(xx, yy)
constraint = ["not", ["and", *cs]]
case ["lexCmp", "<=", xx, yy]:
cs = decomposeLexLe(xx, yy)
constraint = ["and", *cs]
case ["lexCmp", "<", xx, yy]:
cs = decomposeLexLe(xx, yy, less=True)
constraint = ["and", *cs]
case ["lexCmp", ">=", xx, yy]:
cs = decomposeLexLe(yy, xx)
constraint = ["and", *cs]
case ["lexCmp", ">", xx, yy]:
cs = decomposeLexLe(yy, xx, less=True)
constraint = ["and", *cs]
case ["mulCmp", cmp, x, y, z]:
cs = decomposeMulCmp(cmp, x, y, z)
constraint = ["and", *cs]
case ["powCmp", cmp, x, n, z]:
cs = decomposePowCmp(cmp, x, n, z)
constraint = ["and", *cs]
case ["bits", xx, x]:
cs = decomposeBits(xx, x)
constraint = ["and", *cs]
case ["bit", x, k]:
cs = decomposeBit(x, k)
constraint = ["and", *cs]
return constraint
import itertools
from .util import CspsatException, Bool, FALSE, Var, Binary, BinaryEquation