参考python标准库的AST实现。
语法制导翻译也参考标准库,比如参考 ast.NodeVisitor 和 iter_fields来将Space转换成TPE、HEBO等其他超参库的搜索空间定义。
#!/usr/bin/env python
# -*- coding:utf8 -*-
"""
Author: zhaopenghao
Create Time: 2020/5/7 上午10:47
"""
import copy
import itertools
import numpy as np
# Module-level registry of variable nodes keyed by node name: parse_ast() stores every
# htype node it builds here, and Space.__init__ rebinds it to a fresh dict before parsing.
g_vars = {}
def parse_ast(expression):
    """Build the Ast node matching a raw search-space expression.

    Dispatch rules:
      * list -> ListAst
      * dict with an 'htype' key -> ChoiceAst / RandintAst / UniformAst /
        LoguniformAst according to the 'htype' value; the node is also stored
        in the module-level ``g_vars`` under its ``name``
      * any other dict -> DictAst
      * int / float / str -> PrimitiveAst

    Raises:
        ValueError: the 'htype' value is not one of the four supported kinds.
        AssertionError: the expression is none of list, dict, int, float, str.
    """
    if isinstance(expression, list):
        return ListAst(expression)

    if not isinstance(expression, dict):
        # Leaf case: only plain scalars are accepted.
        assert isinstance(expression, (int, float, str)), "Unsupported expression: {}".format(expression)
        return PrimitiveAst(expression)

    if 'htype' not in expression:
        return DictAst(expression)

    # A dict tagged with 'htype' describes a variable rather than a nested mapping.
    htype = expression['htype']
    if htype == 'choice':
        node = ChoiceAst(expression)
    elif htype == 'randint':
        node = RandintAst(expression)
    elif htype == 'uniform':
        node = UniformAst(expression)
    elif htype == 'loguniform':
        node = LoguniformAst(expression)
    else:
        raise ValueError("Unsupported htype: {}, in {}".format(htype, expression))

    g_vars[node.name] = node
    return node
class Ast:
    """Base node of the search-space syntax tree.

    The hooks defined here are placeholders: each one does nothing and returns
    None. Subclasses override the ones they support.
    """

    def __init__(self, expr):
        # Cached answers for is_random() / is_grid(); None means "not computed yet".
        self._is_grid = None
        self._is_random = None
        self.name = ''  # node name; always '' here (original note: "unique generator")
        self.expr = expr  # raw expression this node was built from

    def sample(self, var2meta:'dict of var_name->var_meta'=None) -> 'tuple: config, var_meta_dict':
        """Placeholder for drawing one configuration; returns None in the base class."""
        return None

    def is_random(self):
        """Placeholder; returns None in the base class."""
        return None

    def is_grid(self):
        """Placeholder; returns None in the base class."""
        return None

    def grid(self) -> 'tuple: list of conf, list of var_meta_dict':
        """Placeholder for enumerating all configurations; returns None in the base class."""
        return None
class DictAst(Ast):
    """Mapping node: every key of the source dict owns one child Ast built by parse_ast()."""

    def __init__(self, expr):
        super().__init__(expr)
        # key (str) -> child Ast, in the iteration order of ``expr``
        self.dict = {key: parse_ast(value) for key, value in expr.items()}

    def sample(self, var2meta:'dict of var_name->var_meta'=None) -> 'tuple: config, var_meta_dict':
        """Sample every child and assemble the results.

        Args:
            var2meta: passed through unchanged to each child's ``sample``.

        Returns:
            ``(conf, vmeta)`` where ``conf`` maps each key to its child's
            sampled config and ``vmeta`` is the union of the children's meta dicts.
        """
        conf = {}
        vmeta = {}
        for key, child in self.dict.items():
            conf[key], child_meta = child.sample(var2meta)
            vmeta.update(child_meta)
        return conf, vmeta

    def is_random(self):
        """Return True when at least one child reports is_random(); the answer is cached."""
        if self._is_random is None:
            self._is_random = any(child.is_random() for child in self.dict.values())
        return self._is_random

    def is_grid(self):
        """Return True when every child reports is_grid(); the answer is cached."""
        if self._is_grid is None:
            self._is_grid = all(child.is_grid() for child in self.dict.values())
        return self._is_grid

    def grid(self) -> 'tuple: list of conf, list of var_meta_dict':
        """Enumerate the Cartesian product of the children's grids.

        Returns:
            ``(confs, metas)``: ``confs[i]`` is a dict with one value per key and
            ``metas[i]`` is the merged meta dict of the children. Both lists are
            produced by the same product order, so they line up index by index
            as long as each child returns confs and metas of equal length.
        """
        keys = list(self.dict)
        per_key_confs = []  # [key1_confs, key2_confs, ...]
        per_key_metas = []  # [key1_metas, key2_metas, ...]
        for key in keys:
            key_confs, key_metas = self.dict[key].grid()
            per_key_confs.append(key_confs)
            per_key_metas.append(key_metas)

        # The per-key lists must be unpacked: product(lists) would iterate over the
        # outer list itself and yield 1-tuples holding a whole per-key list.
        confs = [dict(zip(keys, combo)) for combo in itertools.product(*per_key_confs)]

        metas = []
        for combo in itertools.product(*per_key_metas):
            merged = {}
            for sub_meta in combo:
                merged.update(sub_meta)
            metas.append(merged)
        return confs, metas
class ListAst(Ast):
    """Sequence node: one child Ast per element of the source list."""

    def __init__(self, expr):
        # Initialise expr / name / cache fields shared by all nodes.
        super().__init__(expr)
        self.list = [parse_ast(item) for item in expr]  # Ast, ...

    def sample(self, var2meta:'dict of var_name->var_meta'=None) -> 'tuple: config, var_meta_dict':
        """Sample every element.

        Args:
            var2meta: passed through unchanged to each child's ``sample``.

        Returns:
            ``(conf, vmeta)`` where ``conf`` is the list of sampled element
            configs and ``vmeta`` is the union of the children's meta dicts.
        """
        conf = []
        vmeta = {}
        for child in self.list:
            child_conf, child_meta = child.sample(var2meta)
            conf.append(child_conf)
            vmeta.update(child_meta)
        return conf, vmeta

    def is_random(self):
        """Return True when at least one element reports is_random(); cached (mirrors DictAst)."""
        if self._is_random is None:
            self._is_random = any(child.is_random() for child in self.list)
        return self._is_random

    def is_grid(self):
        """Return True when every element reports is_grid(); cached (mirrors DictAst)."""
        if self._is_grid is None:
            self._is_grid = all(child.is_grid() for child in self.list)
        return self._is_grid

    def grid(self) -> 'tuple: list of conf, list of var_meta_dict':
        """Enumerate the Cartesian product of the elements' grids.

        Returns:
            ``(confs, metas)``: ``confs[i]`` is a list with one value per element
            and ``metas[i]`` is the merged meta dict of the elements, both
            generated in the same product order.
        """
        confs_list = []  # [elem1_confs, elem2_confs, ...]
        metas_list = []  # [elem1_metas, elem2_metas, ...]
        for child in self.list:
            sub_confs, sub_metas = child.grid()
            confs_list.append(sub_confs)
            metas_list.append(sub_metas)

        # Unpack so that product() combines the per-element lists with each other.
        confs = [list(combo) for combo in itertools.product(*confs_list)]

        metas = []
        for combo in itertools.product(*metas_list):
            merged = {}
            for sub_meta in combo:
                merged.update(sub_meta)
            metas.append(merged)
        return confs, metas
class PrimitiveAst(Ast):
    """Leaf node holding a constant (str, float or int)."""

    def __init__(self, expr):
        # Initialise expr / name / cache fields shared by all nodes.
        super().__init__(expr)
        self.value = expr  # str, float, int
        self._is_random = False
        # A constant is its own one-point grid (see grid()), so it never blocks
        # an enclosing container from being grid-enumerable.
        self._is_grid = True

    def is_random(self):
        """A constant is never random."""
        return self._is_random

    def is_grid(self):
        """A constant can always be enumerated."""
        return self._is_grid

    def sample(self, var2meta:'dict of var_name->var_meta'=None) -> 'tuple: config, var_meta_dict':
        """Return ``(value, {})``; ``var2meta`` is ignored."""
        return self.value, {}

    def grid(self) -> 'tuple: list of conf, list of var_meta_dict':
        """Return the single-point grid ``([value], [{}])``."""
        return [self.value], [{}]
class VariableAst(Ast):
    """Base class of the tunable-variable nodes (the dicts tagged with 'htype')."""

    def __init__(self, expr):
        super().__init__(expr)
        self._is_random = True
        self._is_conditional = False
        self._is_grid = False  # subclasses that can be enumerated set this to True

    def is_random(self):
        """Return the ``_is_random`` flag (True unless a subclass changes it)."""
        return self._is_random

    def is_conditional(self):
        """Return the ``_is_conditional`` flag."""
        return self._is_conditional

    def is_grid(self):
        """Return the ``_is_grid`` flag."""
        # The flag has to be returned; a bare expression statement would make
        # every variable report None (falsy) regardless of the flag.
        return self._is_grid

    def perturb(self, this_meta) -> 'that_meta':
        """Placeholder for deriving a neighbouring meta value; returns None here."""
        return None
class ChoiceAst(VariableAst):
    """Categorical variable: picks one of the alternatives listed under ``expr['value']``.

    The meta value of this variable is the index of the chosen alternative.
    """

    def __init__(self, expr):
        super().__init__(expr)
        self.values = [parse_ast(alt) for alt in expr['value']]  # Ast, Ast
        self._is_conditional = None  # computed lazily by is_conditional()
        self._is_grid = True

    def perturb(self, this_meta) -> 'that_meta':
        """Move the index one step down or up, clamped to the valid index range.

        The direction is chosen by ``np.random.rand() > 0.5``.
        """
        idx = this_meta
        if np.random.rand() > 0.5:
            idx = max(0, idx - 1)
        else:
            idx = min(len(self.values) - 1, idx + 1)
        return idx

    def sample(self, var2meta:'dict of var_name->var_meta'=None) -> 'tuple: config, var_meta_dict':
        """Pick an alternative and sample it.

        Args:
            var2meta: optional mapping; when it contains this node's name the
                stored index is reused, otherwise an index is drawn uniformly.
                It is passed on unchanged to the chosen alternative.

        Returns:
            ``(conf, vmeta)``: the alternative's sampled config, and a meta dict
            holding this node's index plus the alternative's own metas.
        """
        # var2meta defaults to None, which does not support ``in``.
        if var2meta is not None and self.name in var2meta:
            idx = var2meta[self.name]
        else:
            idx = np.random.randint(0, len(self.values))  # random
        vmeta = {self.name: idx}
        conf, sub_meta = self.values[idx].sample(var2meta)
        vmeta.update(sub_meta)
        # Return the merged dict; returning only sub_meta would lose the chosen index.
        return conf, vmeta

    def is_conditional(self):
        """Return True when any alternative is itself random; the answer is cached."""
        if self._is_conditional is None:
            self._is_conditional = any(alt.is_random() for alt in self.values)
        return self._is_conditional

    def grid(self):
        """Concatenate the grids of all alternatives.

        Returns:
            ``(confs, vmetas)``: every config of every alternative, and for each
            one a meta dict made of this node's index plus the alternative's
            metas (built the same way as in sample()).
        """
        confs = []
        vmetas = []
        for idx, alt in enumerate(self.values):
            sub_confs, sub_metas = alt.grid()
            confs.extend(sub_confs)
            for sub_meta in sub_metas:
                meta = {self.name: idx}
                meta.update(sub_meta)
                vmetas.append(meta)
        return confs, vmetas
class RandintAst(VariableAst):
    """Integer variable drawn from ``[low, high)``.

    NOTE(review): ``high`` is exclusive here (np.random.randint / range), while
    HpoVisitor.visit_RandintAst passes ``high + 1`` to hp.randint - confirm
    which bound convention is intended.
    """

    def __init__(self, expr):
        # Needed for name / _is_random / _is_conditional, which parse_ast,
        # sample() and the inherited accessors read.
        super().__init__(expr)
        self.low, self.high = expr['low'], expr['high']  # e.g. 0, 10
        self._is_grid = True

    def sample(self, var2meta:'dict of var_name->var_meta'=None) -> 'tuple: config, var_meta_dict':
        """Return ``(value, {name: value})``.

        Args:
            var2meta: optional mapping; when it contains this node's name the
                stored value is reused, otherwise a value is drawn with
                ``np.random.randint(low, high)``.
        """
        # var2meta defaults to None, which does not support ``in``.
        if var2meta is not None and self.name in var2meta:
            value = var2meta[self.name]
        else:
            value = np.random.randint(self.low, self.high)
        return value, {self.name: value}

    def grid(self) -> 'tuple: list of conf, list of var_meta_dict':
        """Enumerate ``range(low, high)`` with one ``{name: value}`` meta per value."""
        confs = list(range(self.low, self.high))
        metas = [{self.name: value} for value in confs]
        return confs, metas
class UniformAst(VariableAst):
    """Continuous variable described by ``low`` and ``high``.

    This class defines no sample() or grid() of its own; it inherits the
    placeholders from Ast.
    """

    def __init__(self, expr):
        # Needed for name / _is_random / _is_grid, which parse_ast and the
        # inherited accessors read.
        super().__init__(expr)
        self.low, self.high = expr['low'], expr['high']  # e.g. 0., 1.
class LoguniformAst(VariableAst):
    """Continuous variable described by ``low``, ``high`` and an optional ``base``.

    This class defines no sample() or grid() of its own; it inherits the
    placeholders from Ast.
    """

    def __init__(self, expr):
        # Needed for name / _is_random / _is_grid, which parse_ast and the
        # inherited accessors read.
        super().__init__(expr)
        self.low, self.high = expr['low'], expr['high']  # e.g. 1e-5, 1e-2
        self.base = expr.get('base', 10.)  # defaults to 10 when 'base' is absent
class Space:
    """Search space: owns the parsed Ast tree and the name -> variable-node registry."""

    def __init__(self):
        global g_vars
        self.space_str = ''  # json str
        self.space_raw = {}  # python structure obtained after loading the json
        # Start from an empty registry so that parse_ast() records only this
        # space's variables.
        g_vars = {}
        self.space_ast = parse_ast(self.space_raw)
        self.name2var = g_vars  # name -> var_ast

    def sample(self, var2meta=None) -> 'tuple: config, var_meta_dict':
        """Delegate to the root node's ``sample``, forwarding ``var2meta``."""
        root = self.space_ast
        return root.sample(var2meta)

    def perturb(self, var2meta) -> 'tuple: config, var_meta_dict':
        """Perturb every entry of ``var2meta`` and sample with the result.

        The input mapping is deep-copied first and is not modified; each entry
        of the copy is replaced by the owning variable's ``perturb`` output.
        """
        perturbed = copy.deepcopy(var2meta)
        candidates = var2meta  # filter
        for var_name, meta in candidates.items():
            variable = self.name2var[var_name]
            perturbed[var_name] = variable.perturb(meta)
        return self.space_ast.sample(perturbed)

    def grid(self):
        """Return ``(confs, metas)`` from the root node's ``grid``.

        Asserts beforehand that the root node reports ``is_grid()``.
        """
        assert self.space_ast.is_grid()
        all_confs, all_metas = self.space_ast.grid()
        return all_confs, all_metas
#!/usr/bin/env python
# -*- coding:utf8 -*-
"""
Author: zhaopenghao
Create Time: 2020/5/19 上午10:55
"""
import json
import numpy as np
from rudder_autosearch.space.space import Space
from rudder_autosearch.space.space import Ast
def iter_fields(node):
    """
    Generate ``(fieldname, value)`` pairs for the names listed in
    ``node._fields``, skipping every name that *node* has no attribute for.
    """
    for name in node._fields:
        try:
            value = getattr(node, name)
        except AttributeError:
            # Declared in _fields but not set on this instance.
            continue
        yield name, value
class NodeVisitor(object):
    """Tree walker modelled on the standard library's `ast.NodeVisitor`.

    Output text is accumulated in the class attribute ``result`` through the
    ``write`` / ``clear`` classmethods; this class itself does not define
    ``result``.
    """

    def visit(self, node):
        """Dispatch to ``visit_<ClassName>`` when it exists, else to generic_visit."""
        handler_name = 'visit_' + node.__class__.__name__
        handler = getattr(self, handler_name, self.generic_visit)
        return handler(node)

    def generic_visit(self, node):
        """Fallback: visit each Ast found in the node's fields.

        A field may be a list (items are checked), a dict (values are checked)
        or a single Ast; anything else is skipped.
        """
        for _name, value in iter_fields(node):
            if isinstance(value, list):
                children = value
            elif isinstance(value, dict):
                children = value.values()
            elif isinstance(value, Ast):
                children = (value,)
            else:
                continue
            for child in children:
                if isinstance(child, Ast):
                    self.visit(child)

    def space2hpo_space(self, space):
        """Convert *space* to the target library's search space; subclasses must override."""
        raise NotImplementedError

    @classmethod
    def write(cls, text):
        """Append *text* to the class-level ``result`` buffer."""
        cls.result += text

    @classmethod
    def clear(cls):
        """Reset the class-level ``result`` buffer to the empty string."""
        cls.result = ""
def visit_DictAst_public_one(node_visitor, node):
    """Emit a dict literal for a DictAst-like node.

    Writes ``{``, then for every entry of ``node.dict`` the quoted key followed
    by ``: ``, the visited child and a trailing ``,``, and finally ``}``.

    Args:
        node_visitor: object providing ``write(text)`` and ``visit(node)``.
        node: object whose ``dict`` attribute maps keys to child nodes.
    """
    node_visitor.write('{')
    for key, child in node.dict.items():
        # json.dumps produces a double-quoted literal with quotes, backslashes
        # and control characters escaped, so the emitted text stays a valid
        # string literal for any key. Plain keys come out exactly as before.
        quoted_key = json.dumps('{}'.format(key), ensure_ascii=False)
        node_visitor.write('{}: '.format(quoted_key))
        node_visitor.visit(child)
        node_visitor.write(',')
    node_visitor.write('}')
def visit_ListAst_public_one(node_visitor, node):
    """Emit a list literal for a ListAst-like node.

    Writes ``[``, then each child of ``node.list`` (via ``node_visitor.visit``)
    followed by ``,``, and finally ``]``.
    """
    node_visitor.write('[')
    for child in node.list:
        node_visitor.visit(child)
        node_visitor.write(',')
    node_visitor.write(']')
def visit_PrimitiveAst_public_one(node_visitor, node):
    """Emit the literal for a PrimitiveAst-like node.

    Strings are written as double-quoted, escaped literals; every other value
    is written through ``'{}'.format``.

    Args:
        node_visitor: object providing ``write(text)``.
        node: object whose ``value`` attribute holds the constant.
    """
    value = node.value
    if isinstance(value, str):
        # json.dumps escapes quotes, backslashes and control characters, so the
        # emitted text reads back as the same string. Strings without such
        # characters come out exactly as "<value>", as before.
        literal = json.dumps(value, ensure_ascii=False)
    else:
        literal = '{}'.format(value)
    node_visitor.write(literal)
class HpoVisitor(NodeVisitor):
    """Generates hyperopt search-space source text from a Space tree.

    Modelled on `astunparse.unparser.Unparser`: each ``visit_*`` handler appends
    a code fragment to the shared ``result`` buffer, and space2hpo_space()
    evaluates the finished text.
    """
    result = ""

    def __init__(self):
        self.clear()

    def space2hpo_space(self, space):
        """Return the hyperopt space obtained by evaluating the text generated for *space*.

        NOTE(review): the text is run through eval(); it embeds names, keys and
        values taken from the space definition, so the definition must come
        from a trusted source.
        """
        # hp and scope look unused, but the generated text refers to them by
        # name when eval() runs it in this function's scope.
        from hyperopt import hp
        from hyperopt.pyll import scope
        self.clear()
        self.visit(space.space_ast)
        hpo_space = eval(self.result)
        return hpo_space

    def visit_DictAst(self, node):
        """Emit a dict literal via the shared helper."""
        visit_DictAst_public_one(self, node)

    def visit_ListAst(self, node):
        """Emit a list literal via the shared helper."""
        visit_ListAst_public_one(self, node)

    def visit_PrimitiveAst(self, node):
        """Emit a constant via the shared helper."""
        visit_PrimitiveAst_public_one(self, node)

    def visit_ChoiceAst(self, node):
        """Emit ``hp.choice("<name>", [<alt>,...])`` with every alternative visited in order."""
        opening = 'hp.choice("{}", ['.format(node.name)
        self.write(opening)
        for alternative in node.values:
            self.visit(alternative)
            self.write(',')
        self.write('])')

    def visit_RandintAst(self, node):
        """Emit ``hp.randint("<name>", low, high + 1)``."""
        call = 'hp.randint("{}", {}, {})'.format(node.name, node.low, node.high + 1)
        self.write(call)

    def visit_UniformAst(self, node):
        """Emit ``hp.uniform("<name>", low, high)``."""
        call = 'hp.uniform("{}", {}, {})'.format(node.name, node.low, node.high)
        self.write(call)

    def visit_LoguniformAst(self, node):
        """Emit ``hp.loguniform("<name>", ln(low), ln(high))``.

        The bounds are converted with the natural logarithm before being
        written into the call.
        """
        log_low = np.log(node.low)
        log_high = np.log(node.high)
        call = 'hp.loguniform("{}", {}, {})'.format(node.name, log_low, log_high)
        self.write(call)

    def visit_QuniformAst(self, node):
        """Emit ``low + hp.quniform("<name>", 0, high - low, step)``.

        The call is wrapped in ``scope.int(...)`` when ``node.dtype == "int"``.
        hp.quniform(label, low, high, q) returns a value like
        round(uniform(low, high) / q) * q (see
        <http://hyperopt.github.io/hyperopt/getting-started/search_spaces/>),
        so the range is shifted to start at 0 and ``low`` is added back outside
        the call.
        """
        low = node.low
        high = node.high
        step = node.step
        dtype = node.dtype
        shifted_low = 0
        shifted_high = high - low
        expression = '{} + hp.quniform("{}", {}, {}, {})'.format(
            low, node.name, shifted_low, shifted_high, step)
        if dtype == "int":
            expression = 'scope.int({})'.format(expression)
        self.write(expression)