# to_delete/dreamcoder/domains/regex/regexPrimitives.py
# Repository-viewer residue kept for provenance: uploaded by Fraser-Greenlee,
# commit "add dreamcoder codebase" (e1c1753), 14.5 kB.
import sys
from dreamcoder.program import Primitive
from dreamcoder.grammar import Grammar
from dreamcoder.type import arrow, tpregex
from string import printable
# Best-effort import: pregex may be unavailable (the message below names pypy
# as the only acceptable case), in which case the module still loads but any
# function that touches `pregex` will fail with NameError when called.
try:
    from pregex import pregex
except Exception:
    # `except Exception` instead of a bare `except:` so that KeyboardInterrupt
    # and SystemExit are not swallowed while the import is in progress.
    print("Failure to load pregex. This is only acceptable if using pypy", file=sys.stderr)
# Evaluation to regular regex form; then I can unflatten using Luke's stuff.
def _kleene(x):
    """Return `x` wrapped in a `pregex.KleeneStar` built with p=0.25."""
    starred = pregex.KleeneStar(x, p=0.25)
    return starred
def _plus(x):
    """Return `x` wrapped in a `pregex.Plus` built with p=0.25."""
    repeated = pregex.Plus(x, p=0.25)
    return repeated
def _maybe(x):
    """Return `x` wrapped in a `pregex.Maybe` (default parameters)."""
    optional = pregex.Maybe(x)
    return optional
# maybe should be reversed  # "(" + x + "|" + y + ")"
def _alt(x):
    """Curried alternation: `_alt(x)(y)` builds `pregex.Alt([x, y])`."""
    def alternate_with(y):
        return pregex.Alt([x, y])

    return alternate_with
def _concat(x):
    """Curried concatenation: `_concat(x)(y)` builds `pregex.Concat([x, y])`."""
    def followed_by(y):
        return pregex.Concat([x, y])  # "(" + x + y + ")"

    return followed_by
# For sketch:
def _kleene_5(x):
    """Return `x` wrapped in a `pregex.KleeneStar` using pregex's default p."""
    starred = pregex.KleeneStar(x)
    return starred
def _plus_5(x):
    """Return `x` wrapped in a `pregex.Plus` using pregex's default p."""
    repeated = pregex.Plus(x)
    return repeated
# (character, name) pairs for characters that get a spelled-out primitive
# name ("string_" + name) instead of "string_" + the character itself.
#
# NOTE(review): the names for ")" / "(" are swapped, and "astrisk" / "carrot"
# are misspelled. They are kept exactly as-is because they become primitive
# names at runtime; renaming them would change those names.
disallowed = [
    # punctuation
    ("#", "hash"),
    ("!", "bang"),
    ("\"", "double_quote"),
    ("$", "dollar"),
    ("%", "percent"),
    ("&", "ampersand"),
    ("'", "single_quote"),
    (")", "left_paren"),
    ("(", "right_paren"),
    ("*", "astrisk"),
    ("+", "plus"),
    (",", "comma"),
    ("-", "dash"),
    (".", "period"),
    ("/", "slash"),
    (":", "colon"),
    (";", "semicolon"),
    ("<", "less_than"),
    ("=", "equal"),
    (">", "greater_than"),
    ("?", "question_mark"),
    ("@", "at"),
    ("[", "left_bracket"),
    ("\\", "backslash"),
    ("]", "right_bracket"),
    ("^", "carrot"),
    ("_", "underscore"),
    ("`", "backtick"),
    ("|", "bar"),
    ("}", "right_brace"),
    ("{", "left_brace"),
    ("~", "tilde"),
    # whitespace
    (" ", "space"),
    ("\t", "tab"),
]

# Just the characters, in the same order as `disallowed`.
disallowed_list = [pair[0] for pair in disallowed]
class PRC:  # PregexContinuation
    """Continuation-style wrapper around a pregex object or constructor.

    Attributes:
        f: a pregex object (when arity is 0) or a callable that builds one
            from the collected sub-regexes (when arity >= 1).
        arity: how many sub-regexes must be collected before `f` is applied.
        args: the sub-regexes collected so far.

    Calling an instance either collects one more argument (returning a new
    PRC) or, once `arity` arguments are present, returns
    `pregex.Concat([<f applied>, pre])`.
    """

    def __init__(self, f, arity=0, args=None):
        self.f = f
        self.arity = arity
        # Fresh list per instance: a mutable default (`args=[]`) would be one
        # list object shared by every PRC constructed without `args`.
        self.args = [] if args is None else list(args)

    def __call__(self, pre):
        """Collect `pre` as an argument, or prepend the finished regex to it.

        While fewer than `arity` arguments are stored, `pre` is treated as a
        continuation: it is called on the empty-string regex and the result is
        appended to a copy of `args` in a new PRC. Otherwise `pre` is the
        regex that follows, and the concatenation is returned.
        """
        if self.arity != len(self.args):
            collected = self.args + [pre(pregex.String(""))]
            return PRC(self.f, self.arity, args=collected)

        if self.arity == 0:
            head = self.f
        elif self.arity == 1:
            head = self.f(*self.args)
        else:
            # NOTE (from the original author): this line is bad, need brackets
            # around input to f if f is Alt. `f` receives the whole list as a
            # single argument here.
            head = self.f(self.args)
        return pregex.Concat([head, pre])
def concatPrimitives():
    """Return the continuation-style ("concat") regex primitive set.

    Every primitive value is a PRC. The list contains, in order:
    one "string_<char>" primitive per character of `printable[:-4]` that is
    not in `disallowed_list`, one "string_<name>" primitive per entry of
    `disallowed`, the character classes r_dot/r_d/r_s/r_w/r_l/r_u, and the
    operators r_kleene/r_plus/r_maybe (arity 1) and r_alt (arity 2).
    """
    def leaf(name, regex):
        # Arity-0 continuation around an already-built pregex object.
        return Primitive(name, arrow(tpregex, tpregex), PRC(regex))

    def unary(name, constructor):
        # Arity-1 continuation: one sub-regex is collected before `constructor` runs.
        return Primitive(
            name,
            arrow(arrow(tpregex, tpregex), arrow(tpregex, tpregex)),
            PRC(constructor, 1),
        )

    primitives = []
    for ch in printable[:-4]:
        if ch not in disallowed_list:
            primitives.append(leaf("string_" + ch, pregex.String(ch)))
    for ch, label in disallowed:
        primitives.append(leaf("string_" + label, pregex.String(ch)))

    primitives += [
        leaf("r_dot", pregex.dot),
        leaf("r_d", pregex.d),
        leaf("r_s", pregex.s),
        leaf("r_w", pregex.w),
        leaf("r_l", pregex.l),
        leaf("r_u", pregex.u),
        # todo
        unary("r_kleene", pregex.KleeneStar),
        unary("r_plus", pregex.Plus),
        unary("r_maybe", pregex.Maybe),
        Primitive(
            "r_alt",
            arrow(arrow(tpregex, tpregex), arrow(tpregex, tpregex), arrow(tpregex, tpregex)),
            PRC(pregex.Alt, 2),
        ),
    ]
    return primitives
def strConstConcatPrimitives():
    """Return the concat primitive set followed by the `r_const` primitive.

    The first part is exactly what `concatPrimitives()` builds (same names,
    types, values and order), so it is delegated rather than repeated.
    `r_const` has type tpregex -> tpregex and value None here.
    """
    shared = concatPrimitives()
    const_primitive = Primitive("r_const", arrow(tpregex, tpregex), None)
    return shared + [const_primitive]
def reducedConcatPrimitives():
    """Return the reduced continuation-style primitive set.

    Same layout as the full concat set, but without r_w and r_plus (both were
    commented out by the original author), and with the `r_const` primitive
    (value None) appended. No "empty_string" primitive is included; it was
    also commented out originally.
    """
    # uses strConcat!!
    leaf_classes = [
        ("r_dot", pregex.dot),
        ("r_d", pregex.d),
        ("r_s", pregex.s),
        # r_w intentionally left out
        ("r_l", pregex.l),
        ("r_u", pregex.u),
    ]
    # todo
    unary_operators = [
        ("r_kleene", pregex.KleeneStar),
        # r_plus intentionally left out
        ("r_maybe", pregex.Maybe),
    ]

    result = [
        Primitive("string_" + ch, arrow(tpregex, tpregex), PRC(pregex.String(ch)))
        for ch in printable[:-4]
        if ch not in disallowed_list
    ]
    result += [
        Primitive("string_" + label, arrow(tpregex, tpregex), PRC(pregex.String(ch)))
        for ch, label in disallowed
    ]
    result += [
        Primitive(name, arrow(tpregex, tpregex), PRC(regex))
        for name, regex in leaf_classes
    ]
    result += [
        Primitive(
            name,
            arrow(arrow(tpregex, tpregex), arrow(tpregex, tpregex)),
            PRC(constructor, 1),
        )
        for name, constructor in unary_operators
    ]
    result.append(
        Primitive(
            "r_alt",
            arrow(arrow(tpregex, tpregex), arrow(tpregex, tpregex), arrow(tpregex, tpregex)),
            PRC(pregex.Alt, 2),
        )
    )
    result.append(Primitive("r_const", arrow(tpregex, tpregex), None))
    return result
def sketchPrimitives():
    """Return the direct-style primitive set used for sketches.

    Values are plain pregex objects / functions (no PRC wrapping). r_kleene
    and r_plus use the `_kleene_5` / `_plus_5` helpers, which call the pregex
    constructors without an explicit `p`.
    """
    plain_literals = [
        Primitive("string_" + ch, tpregex, pregex.String(ch))
        for ch in printable[:-4]
        if ch not in disallowed_list
    ]
    named_literals = [
        Primitive("string_" + label, tpregex, pregex.String(ch))
        for ch, label in disallowed
    ]
    character_classes = [
        Primitive("r_dot", tpregex, pregex.dot),
        Primitive("r_d", tpregex, pregex.d),
        Primitive("r_s", tpregex, pregex.s),
        Primitive("r_w", tpregex, pregex.w),
        Primitive("r_l", tpregex, pregex.l),
        Primitive("r_u", tpregex, pregex.u),
    ]
    operators = [
        Primitive("r_kleene", arrow(tpregex, tpregex), _kleene_5),
        Primitive("r_plus", arrow(tpregex, tpregex), _plus_5),
        Primitive("r_maybe", arrow(tpregex, tpregex), _maybe),
        Primitive("r_alt", arrow(tpregex, tpregex, tpregex), _alt),
        Primitive("r_concat", arrow(tpregex, tpregex, tpregex), _concat),
    ]
    return plain_literals + named_literals + character_classes + operators
def basePrimitives():
    """Return the base direct-style primitive set.

    Literal primitives for every character of `printable[:-4]` (named by
    character, or by the spelled-out name for entries of `disallowed`), the
    six character classes, and r_kleene/r_plus/r_maybe/r_alt/r_concat.
    r_kleene and r_plus use the p=0.25 helpers.
    """
    prims = []
    prims.extend(
        Primitive("string_" + ch, tpregex, pregex.String(ch))
        for ch in printable[:-4]
        if ch not in disallowed_list
    )
    prims.extend(
        Primitive("string_" + label, tpregex, pregex.String(ch))
        for ch, label in disallowed
    )
    prims.extend([
        Primitive("r_dot", tpregex, pregex.dot),
        Primitive("r_d", tpregex, pregex.d),
        Primitive("r_s", tpregex, pregex.s),
        Primitive("r_w", tpregex, pregex.w),
        Primitive("r_l", tpregex, pregex.l),
        Primitive("r_u", tpregex, pregex.u),
        Primitive("r_kleene", arrow(tpregex, tpregex), _kleene),
        Primitive("r_plus", arrow(tpregex, tpregex), _plus),
        Primitive("r_maybe", arrow(tpregex, tpregex), _maybe),
        Primitive("r_alt", arrow(tpregex, tpregex, tpregex), _alt),
        Primitive("r_concat", arrow(tpregex, tpregex, tpregex), _concat),
    ])
    return prims
def altPrimitives():
    """Return the base primitive set with "empty_string" added and r_plus removed.

    "empty_string" comes first, followed by the literal primitives, the six
    character classes, and r_kleene/r_maybe/r_alt/r_concat.
    """
    def atom(name, regex):
        # A constant of type tpregex.
        return Primitive(name, tpregex, regex)

    collected = [atom("empty_string", pregex.String(""))]
    for ch in printable[:-4]:
        if ch not in disallowed_list:
            collected.append(atom("string_" + ch, pregex.String(ch)))
    for ch, label in disallowed:
        collected.append(atom("string_" + label, pregex.String(ch)))

    collected += [
        atom("r_dot", pregex.dot),
        atom("r_d", pregex.d),
        atom("r_s", pregex.s),
        atom("r_w", pregex.w),
        atom("r_l", pregex.l),
        atom("r_u", pregex.u),
        Primitive("r_kleene", arrow(tpregex, tpregex), _kleene),
        # r_plus intentionally left out
        Primitive("r_maybe", arrow(tpregex, tpregex), _maybe),
        Primitive("r_alt", arrow(tpregex, tpregex, tpregex), _alt),
        Primitive("r_concat", arrow(tpregex, tpregex, tpregex), _concat),
    ]
    return collected
def alt2Primitives():
    """Return the alt primitive set without r_plus and without r_maybe.

    Contents, in order: "empty_string", the literal primitives, the six
    character classes, then r_kleene, r_alt and r_concat.
    """
    unary_type = arrow(tpregex, tpregex)
    return [
        Primitive("empty_string", tpregex, pregex.String("")),
        *(
            Primitive("string_" + ch, tpregex, pregex.String(ch))
            for ch in printable[:-4]
            if ch not in disallowed_list
        ),
        *(
            Primitive("string_" + label, tpregex, pregex.String(ch))
            for ch, label in disallowed
        ),
        Primitive("r_dot", tpregex, pregex.dot),
        Primitive("r_d", tpregex, pregex.d),
        Primitive("r_s", tpregex, pregex.s),
        Primitive("r_w", tpregex, pregex.w),
        Primitive("r_l", tpregex, pregex.l),
        Primitive("r_u", tpregex, pregex.u),
        Primitive("r_kleene", unary_type, _kleene),
        # r_plus and r_maybe intentionally left out
        Primitive("r_alt", arrow(tpregex, tpregex, tpregex), _alt),
        Primitive("r_concat", arrow(tpregex, tpregex, tpregex), _concat),
    ]
def easyWordsPrimitives():
    """Return a small primitive set restricted to ASCII letters.

    Literal primitives cover `printable[10:62]` (the lowercase then uppercase
    ASCII letters); there are no spelled-out literals, no r_dot and no r_w.
    The classes r_d/r_s/r_l/r_u and all five operators are included.
    """
    letters = printable[10:62]
    letter_literals = []
    for ch in letters:
        if ch in disallowed_list:
            continue
        letter_literals.append(Primitive("string_" + ch, tpregex, pregex.String(ch)))

    classes_and_operators = [
        Primitive("r_d", tpregex, pregex.d),
        Primitive("r_s", tpregex, pregex.s),
        # r_w intentionally left out
        Primitive("r_l", tpregex, pregex.l),
        Primitive("r_u", tpregex, pregex.u),
        Primitive("r_kleene", arrow(tpregex, tpregex), _kleene),
        Primitive("r_plus", arrow(tpregex, tpregex), _plus),
        Primitive("r_maybe", arrow(tpregex, tpregex), _maybe),
        Primitive("r_alt", arrow(tpregex, tpregex, tpregex), _alt),
        Primitive("r_concat", arrow(tpregex, tpregex, tpregex), _concat),
    ]
    return letter_literals + classes_and_operators
#def _wrapper(x): return lambda y: y
#specials = [".","*","+","?","|"]
"""
>>> import pregex as pre
>>> abc = pre.CharacterClass("abc", [0.1, 0.1, 0.8], name="MyConcept")
>>> abc.sample()
'b'
>>> abc.sample()
'c'
>>> abc.sample()
'c'
>>> abc.match("c")
-0.2231435513142097
>>> abc.match("a")
-2.3025850929940455
>>> abc
MyConcept
>>> x = pre.KleeneStar(abc)
>>> x.match("aabbac")
-16.58809928020405
>>> x.sample()
''
>>> x.sample()
''
>>> x.sample()
'cbcacc'
>>> x
(KleeneStar 0.5 MyConcept)
>>> str(x)
'MyConcept*'
"""
def emp_dot(corpus):
    """Build the "." character class over `printable[:-4]`, weighted by the corpus.

    `printable[:-4]` is every printable character except the last four
    whitespace characters of `string.printable`.
    """
    alphabet = printable[:-4]
    weights = emp_distro_from_corpus(corpus, alphabet)
    return pregex.CharacterClass(alphabet, weights, name=".")
def emp_d(corpus):
    """Build the "\\d" character class over the digits, weighted by the corpus."""
    digits = printable[:10]
    weights = emp_distro_from_corpus(corpus, digits)
    return pregex.CharacterClass(digits, weights, name="\\d")
#emp_s = pre.CharacterClass(slist, [], name="emp\\s") #may want to forgo this one.
def emp_dot_no_letter(corpus):
    """Build a "." character class without ASCII letters, weighted by the corpus.

    The alphabet is the digits plus everything after the letters in
    `string.printable`.
    """
    # NOTE(review): unlike emp_dot, this slice runs to the end of `printable`
    # and so keeps its last four whitespace characters - confirm this is intended.
    non_letters = printable[:10] + printable[62:]
    weights = emp_distro_from_corpus(corpus, non_letters)
    return pregex.CharacterClass(non_letters, weights, name=".")
def emp_w(corpus):
    """Build the "\\w" character class over digits and ASCII letters, weighted by the corpus."""
    word_chars = printable[:62]
    weights = emp_distro_from_corpus(corpus, word_chars)
    return pregex.CharacterClass(word_chars, weights, name="\\w")
def emp_l(corpus):
    """Build the "\\l" character class over lowercase ASCII letters, weighted by the corpus."""
    lowercase = printable[10:36]
    weights = emp_distro_from_corpus(corpus, lowercase)
    return pregex.CharacterClass(lowercase, weights, name="\\l")
def emp_u(corpus):
    """Build the "\\u" character class over uppercase ASCII letters, weighted by the corpus."""
    uppercase = printable[36:62]
    weights = emp_distro_from_corpus(corpus, uppercase)
    return pregex.CharacterClass(uppercase, weights, name="\\u")
def emp_distro_from_corpus(corpus, char_list):
    """Return the empirical distribution of `char_list` over the corpus.

    Characters are counted from the second element (`example[1]`) of every
    example of every task in `corpus`; each such element is iterated, and each
    item it yields is iterated again character by character.

    Args:
        corpus: iterable of objects exposing an `examples` iterable.
        char_list: the characters to report on, in output order.

    Returns:
        A list of relative frequencies, one per character of `char_list`,
        normalised over the characters of `char_list` only. If none of them
        occurs in the corpus, a uniform distribution is returned instead of
        dividing by zero. An empty `char_list` yields an empty list.
    """
    from collections import Counter

    counts = Counter(
        char
        for task in corpus
        for example in task.examples
        for string in example[1]
        for char in string
    )
    total = sum(counts[char] for char in char_list)
    if total == 0:
        # No evidence at all for this alphabet: fall back to uniform weights
        # rather than raising ZeroDivisionError.
        return [1 / len(char_list) for _ in char_list]
    return [counts[char] / total for char in char_list]
def matchEmpericalPrimitives(corpus):
    """Return a zero-argument factory for corpus-weighted primitives.

    Nothing is built until the returned function is called. Each call builds
    "empty_string", the literal primitives, then r_dot/r_d/r_w/r_l/r_u as
    character classes weighted by `corpus` (r_s stays `pregex.s`), followed by
    r_kleene, r_maybe, r_alt and r_concat. r_plus is not included.
    """
    def build():
        primitives = [Primitive("empty_string", tpregex, pregex.String(""))]
        primitives += [
            Primitive("string_" + ch, tpregex, pregex.String(ch))
            for ch in printable[:-4]
            if ch not in disallowed_list
        ]
        primitives += [
            Primitive("string_" + label, tpregex, pregex.String(ch))
            for ch, label in disallowed
        ]
        primitives += [
            Primitive("r_dot", tpregex, emp_dot(corpus)),
            Primitive("r_d", tpregex, emp_d(corpus)),
            Primitive("r_s", tpregex, pregex.s),
            Primitive("r_w", tpregex, emp_w(corpus)),
            Primitive("r_l", tpregex, emp_l(corpus)),
            Primitive("r_u", tpregex, emp_u(corpus)),
            Primitive("r_kleene", arrow(tpregex, tpregex), _kleene),
            # r_plus intentionally left out
            Primitive("r_maybe", arrow(tpregex, tpregex), _maybe),
            Primitive("r_alt", arrow(tpregex, tpregex, tpregex), _alt),
            Primitive("r_concat", arrow(tpregex, tpregex, tpregex), _concat),
        ]
        return primitives

    return build
def matchEmpericalNoLetterPrimitives(corpus):
    """Return a zero-argument factory for corpus-weighted, letter-free primitives.

    Nothing is built until the returned function is called. Each call builds
    "empty_string", "string_<char>" literals for the characters of
    `printable[:-4]` that are neither in `disallowed_list` nor ASCII letters,
    the spelled-out literals, r_dot (letter-free, corpus-weighted), r_d
    (corpus-weighted), r_s, and r_kleene/r_alt/r_concat. r_w, r_l, r_u,
    r_plus and r_maybe are not included.
    """
    def build():
        # Characters that do not get a "string_<char>" primitive; built once
        # per call instead of once per candidate character.
        excluded = disallowed_list + list(printable[10:62])

        primitives = [Primitive("empty_string", tpregex, pregex.String(""))]
        for ch in printable[:-4]:
            if ch not in excluded:
                primitives.append(Primitive("string_" + ch, tpregex, pregex.String(ch)))
        for ch, label in disallowed:
            primitives.append(Primitive("string_" + label, tpregex, pregex.String(ch)))
        primitives += [
            Primitive("r_dot", tpregex, emp_dot_no_letter(corpus)),
            Primitive("r_d", tpregex, emp_d(corpus)),
            Primitive("r_s", tpregex, pregex.s),
            Primitive("r_kleene", arrow(tpregex, tpregex), _kleene),
            # r_plus and r_maybe intentionally left out
            Primitive("r_alt", arrow(tpregex, tpregex, tpregex), _alt),
            Primitive("r_concat", arrow(tpregex, tpregex, tpregex), _concat),
        ]
        return primitives

    return build
if __name__ == '__main__':
    # Manual smoke test: parse and run one hand-written program, then sample
    # 100 programs from a uniform grammar over the concat primitives.

    # Built before parsing - presumably so the primitive names exist when
    # Program.parse runs; TODO confirm.
    concatPrimitives()
    from dreamcoder.program import Program

    source = "(lambda (r_kleene (lambda (r_maybe (lambda (string_x $0)) $0)) $0))"
    p = Program.parse(source)
    print(p)
    print(p.runWithArguments([pregex.String("")]))

    prims = concatPrimitives()
    g = Grammar.uniform(prims)

    for _ in range(100):
        prog = g.sample(arrow(tpregex, tpregex))
        # Run the sampled continuation on the empty-string regex to get a plain pregex.
        preg = prog.runWithArguments([pregex.String("")])
        print("preg:", repr(preg))
        print("sample:", preg.sample())