参考python标准库的AST实现。

语法制导翻译也参考标准库，比如参考 ast.NodeVisitor 和 ast.iter_fields，将 Space 转换成 TPE、HEBO 等其他超参库的搜索空间定义。

抽象语法树

#!/usr/bin/env python
# -*- coding:utf8 -*-
"""
 Author: zhaopenghao
 Create Time: 2020/5/7 上午10:47
"""
import copy
import itertools
import numpy as np

# Module-level registry of variable nodes (name -> variable Ast instance).
# parse_ast() adds every variable node it builds; Space.__init__ rebinds it to
# a fresh dict before parsing so each Space gets its own registry.
g_vars = {}

def parse_ast(expression):
    """Recursively parse a raw search-space expression into an Ast tree.

    Dispatch rules:
      * ``list``                    -> ListAst
      * ``dict`` with an ``htype``  -> ChoiceAst / RandintAst / UniformAst /
        LoguniformAst, also registered in the module-level ``g_vars``
      * any other ``dict``          -> DictAst
      * ``int`` / ``float`` / ``str`` -> PrimitiveAst

    Variable nodes that carry no name are given a generated one
    (``"<htype>_<n>"``) that is unique within ``g_vars``. Without this, every
    variable would be registered under ``''``, overwriting each other in the
    registry and producing duplicate labels in the translated spaces.

    Raises:
        ValueError: unsupported ``htype``, or a duplicate explicit name.
        AssertionError: unsupported primitive type.
    """
    if isinstance(expression, list):
        return ListAst(expression)
    if isinstance(expression, dict):
        if 'htype' not in expression:
            return DictAst(expression)
        htype = expression['htype']
        var_cls = {
            'choice': ChoiceAst,
            'randint': RandintAst,
            'uniform': UniformAst,
            'loguniform': LoguniformAst,
        }.get(htype)
        if var_cls is None:
            raise ValueError("Unsupported htype: {}, in {}".format(htype, expression))
        var_ast = var_cls(expression)
        # getattr: some variable classes may not have run Ast.__init__.
        name = getattr(var_ast, 'name', '')
        if not name:
            idx = len(g_vars)
            name = '{}_{}'.format(htype, idx)
            while name in g_vars:
                idx += 1
                name = '{}_{}'.format(htype, idx)
        elif name in g_vars:
            raise ValueError("Duplicate variable name: {}, in {}".format(name, expression))
        var_ast.name = name
        g_vars[name] = var_ast
        return var_ast
    assert isinstance(expression, (int, float, str)), "Unsupported expression: {}".format(expression)
    return PrimitiveAst(expression)

class Ast:
    """Base class of every search-space syntax-tree node.

    Attributes:
        expr: the raw expression the node was built from.
        name: variable name, an empty string by default.
        _is_random / _is_grid: caches for subclasses; ``None`` = not computed.

    Every query method here is a placeholder that returns ``None``;
    concrete node classes override the ones they support.
    """

    def __init__(self, expr):
        self.expr = expr
        self.name = ''  # unique generator
        self._is_random, self._is_grid = None, None

    def sample(self, var2meta: 'dict of var_name->var_meta' = None) -> 'tuple: config, var_meta_dict':
        """Draw ``(config, var_meta_dict)``; placeholder returning None."""
        return None

    def is_random(self):
        """Whether the subtree contains a random variable; placeholder returning None."""
        return None

    def is_grid(self):
        """Whether the subtree can be enumerated by grid(); placeholder returning None."""
        return None

    def grid(self) -> 'tuple: list of conf, list of var_meta_dict':
        """Enumerate ``(confs, metas)``; placeholder returning None."""
        return None

class DictAst(Ast):
    """Node for a plain mapping; each value is parsed into its own Ast.

    ``self.dict`` maps every original key to the Ast parsed from its value.
    """

    def __init__(self, expr):
        super().__init__(expr)
        self.dict = {k: parse_ast(v) for k, v in expr.items()}  # str->Ast

    def sample(self, var2meta: 'dict of var_name->var_meta' = None) -> 'tuple: config, var_meta_dict':
        """Sample every value; return the config dict and the merged var metas."""
        conf = {}
        vmeta = {}
        for k, v in self.dict.items():
            c, vm = v.sample(var2meta)
            conf[k] = c
            vmeta.update(vm)
        return conf, vmeta

    def is_random(self):
        """True if any value is random (cached after the first call)."""
        if self._is_random is None:
            self._is_random = any(child.is_random() for child in self.dict.values())
        return self._is_random

    def is_grid(self):
        """True if every value is gridable (cached after the first call)."""
        if self._is_grid is None:
            self._is_grid = all(child.is_grid() for child in self.dict.values())
        return self._is_grid

    def grid(self) -> 'tuple: list of conf, list of var_meta_dict':
        """Return the cartesian product of the values' grids.

        Returns two aligned lists: ``confs[i]`` is a dict with the same keys as
        this node, and ``metas[i]`` is the union of the per-key metas that
        produced it. An empty mapping yields a single empty configuration.
        """
        keys = list(self.dict)
        # Per key, pair each sub-config with its meta so they stay aligned.
        per_key = []
        for key in keys:
            key_confs, key_metas = self.dict[key].grid()
            per_key.append(list(zip(key_confs, key_metas)))
        confs = []
        metas = []
        # Unpack with * so product combines one entry from each key (the
        # original passed the outer list as a single iterable).
        for combo in itertools.product(*per_key):
            confs.append({k: conf for k, (conf, _) in zip(keys, combo)})
            var_meta_dict = {}
            for _, sub_meta in combo:
                var_meta_dict.update(sub_meta)
            metas.append(var_meta_dict)
        return confs, metas

class ListAst(Ast):
    """Node for a list; each element is parsed into its own Ast."""

    def __init__(self, expr):
        # Run Ast.__init__ so expr/name and the _is_* caches exist.
        super().__init__(expr)
        self.list = [parse_ast(v) for v in expr]  # Ast, ...

    def sample(self, var2meta: 'dict of var_name->var_meta' = None) -> 'tuple: config, var_meta_dict':
        """Sample every element; return the config list and the merged var metas."""
        conf = []
        vmeta = {}
        for v in self.list:
            c, vm = v.sample(var2meta)
            conf.append(c)
            vmeta.update(vm)
        return conf, vmeta

    def is_random(self):
        """True if any element is random (cached after the first call)."""
        if self._is_random is None:
            self._is_random = any(child.is_random() for child in self.list)
        return self._is_random

    def is_grid(self):
        """True if every element is gridable (cached after the first call)."""
        if self._is_grid is None:
            self._is_grid = all(child.is_grid() for child in self.list)
        return self._is_grid

    def grid(self) -> 'tuple: list of conf, list of var_meta_dict':
        """Return the cartesian product of the elements' grids.

        Returns two aligned lists: ``confs[i]`` is a list with one entry per
        element, and ``metas[i]`` is the union of the element metas.
        """
        per_item = []
        for child in self.list:
            sub_confs, sub_metas = child.grid()
            per_item.append(list(zip(sub_confs, sub_metas)))
        confs = []
        metas = []
        for combo in itertools.product(*per_item):
            # list(...) builds a real list; the original list[...] built a
            # typing alias.
            confs.append([conf for conf, _ in combo])
            var_meta_dict = {}
            for _, sub_meta in combo:
                var_meta_dict.update(sub_meta)
            metas.append(var_meta_dict)
        return confs, metas

class PrimitiveAst(Ast):
    """Leaf node holding a constant (str, int or float)."""

    def __init__(self, expr):
        super().__init__(expr)
        self.value = expr  # str, float, int
        self._is_random = False
        # A constant is a one-point grid. Without an explicit is_grid, the
        # base class returned None, making every enclosing dict/list
        # non-gridable.
        self._is_grid = True

    def is_random(self):
        """Always False: a constant never varies."""
        return self._is_random

    def is_grid(self):
        """Always True: grid() enumerates exactly one value."""
        return self._is_grid

    def sample(self, var2meta: 'dict of var_name->var_meta' = None) -> 'tuple: config, var_meta_dict':
        """Return the constant and an empty meta dict."""
        return self.value, {}

    def grid(self) -> 'tuple: list of conf, list of var_meta_dict':
        """Return a single-point grid: ``([value], [{}])``."""
        return [self.value], [{}]

class VariableAst(Ast):
    """Base class for nodes that introduce a named hyper-parameter.

    By default a variable is random, not conditional and not gridable;
    subclasses adjust these flags.
    """

    def __init__(self, expr):
        super().__init__(expr)
        self._is_random = True
        self._is_conditional = False
        self._is_grid = False

    def is_random(self):
        """Whether this variable is random."""
        return self._is_random

    def is_conditional(self):
        """Whether this variable's options contain further random variables."""
        return self._is_conditional

    def is_grid(self):
        """Whether this variable can be enumerated by grid()."""
        # The original evaluated the attribute without returning it, so every
        # variable reported None.
        return self._is_grid

    # def is_noncond_random(self): pass
    def perturb(self, this_meta) -> 'that_meta':
        """Return a meta value near *this_meta*; default returns None (subclasses override)."""
        return None

class ChoiceAst(VariableAst):
    """Categorical variable choosing one of ``expr['value']``.

    Each option is parsed into an Ast. The variable's meta is the index of
    the chosen option.
    """

    def __init__(self, expr):
        super().__init__(expr)
        self.values = [parse_ast(v) for v in expr['value']]  # Ast, Ast
        self._is_conditional = None
        # Computed lazily: a choice is gridable only if every option is.
        self._is_grid = None

    def perturb(self, this_meta) -> 'that_meta':
        """Move the chosen index one step left or right, clipped to the valid range."""
        idx = this_meta
        if np.random.rand() > 0.5:
            idx = max(0, idx - 1)
        else:
            idx = min(len(self.values) - 1, idx + 1)
        return idx

    def sample(self, var2meta: 'dict of var_name->var_meta' = None) -> 'tuple: config, var_meta_dict':
        """Pick an option (from *var2meta* if present, else at random) and sample it.

        Returns the sampled option's config and a meta dict containing this
        variable's index plus the option's own metas.
        """
        var2meta = {} if var2meta is None else var2meta  # sample(None) used to raise TypeError
        if self.name in var2meta:
            idx = var2meta[self.name]
        else:
            idx = np.random.randint(0, len(self.values))  # random
        conf, sub_meta = self.values[idx].sample(var2meta)
        vmeta = {self.name: idx}
        vmeta.update(sub_meta)
        # Return the merged meta; the original returned only the option's
        # metas and lost the chosen index.
        return conf, vmeta

    def is_conditional(self):
        """True if any option contains a random variable (cached)."""
        if self._is_conditional is None:
            self._is_conditional = any(v.is_random() for v in self.values)
        return self._is_conditional

    def is_grid(self):
        """True if every option is gridable (cached)."""
        if self._is_grid is None:
            self._is_grid = all(v.is_grid() for v in self.values)
        return self._is_grid

    def grid(self):
        """Concatenate the grids of all options.

        Each meta records this variable's index, so it can be fed back to
        sample() to reproduce the config.
        """
        confs = []
        vmetas = []
        for idx, option in enumerate(self.values):
            sub_confs, sub_metas = option.grid()
            confs.extend(sub_confs)
            for sub_meta in sub_metas:
                meta = {self.name: idx}
                meta.update(sub_meta)
                vmetas.append(meta)
        return confs, vmetas

class RandintAst(VariableAst):
    """Integer variable drawn from ``[low, high)``.

    ``high`` is exclusive, as with ``np.random.randint`` and ``range``. The
    variable's meta is the value itself.
    """

    def __init__(self, expr):
        # Run Ast.__init__ so ``name`` exists (parse_ast reads it).
        super().__init__(expr)
        self.low, self.high = expr['low'], expr['high']  # 0, 10
        self._is_grid = True

    def perturb(self, this_meta) -> 'that_meta':
        """Move the value one step down or up, clipped to ``[low, high - 1]``.

        Mirrors ChoiceAst.perturb. Without this override the inherited
        perturb() returned None, and Space.perturb() stored None as the value.
        """
        value = this_meta
        if np.random.rand() > 0.5:
            value = max(self.low, value - 1)
        else:
            value = min(self.high - 1, value + 1)
        return value

    def sample(self, var2meta: 'dict of var_name->var_meta' = None) -> 'tuple: config, var_meta_dict':
        """Use the value from *var2meta* if present, else draw one at random."""
        var2meta = {} if var2meta is None else var2meta
        if self.name in var2meta:
            value = var2meta[self.name]
        else:
            value = np.random.randint(self.low, self.high)
        return value, {self.name: value}

    def grid(self) -> 'tuple: list of conf, list of var_meta_dict':
        """Enumerate every integer in ``[low, high)`` with its meta."""
        confs = []
        metas = []
        for value in range(self.low, self.high):
            confs.append(value)
            metas.append({self.name: value})
        return confs, metas

class UniformAst(VariableAst):
    """Continuous variable drawn uniformly from ``[low, high)``.

    The variable's meta is the sampled value. Not gridable.
    TODO(review): perturb() is not overridden, so it returns the base-class
    None.
    """

    def __init__(self, expr):
        # Run Ast.__init__ so ``name`` exists (parse_ast reads it).
        super().__init__(expr)
        self.low, self.high = expr['low'], expr['high']  # 0., 1.

    def sample(self, var2meta: 'dict of var_name->var_meta' = None) -> 'tuple: config, var_meta_dict':
        """Use the value from *var2meta* if present, else draw one uniformly."""
        # Without this method, sample() returned None and crashed the caller's
        # tuple unpacking.
        var2meta = {} if var2meta is None else var2meta
        if self.name in var2meta:
            value = var2meta[self.name]
        else:
            value = float(np.random.uniform(self.low, self.high))
        return value, {self.name: value}

class LoguniformAst(VariableAst):
    """Continuous variable whose logarithm is uniform on ``[log(low), log(high))``.

    ``base`` (default ``10.``) is stored from the expression. The
    distribution does not depend on the log base, so sampling uses natural
    logs. The variable's meta is the sampled value. Not gridable.
    TODO(review): perturb() is not overridden, so it returns the base-class
    None.
    """

    def __init__(self, expr):
        # Run Ast.__init__ so ``name`` exists (parse_ast reads it).
        super().__init__(expr)
        self.low, self.high = expr['low'], expr['high']  # 1e-5, 1e-2
        self.base = expr.get('base', 10.)

    def sample(self, var2meta: 'dict of var_name->var_meta' = None) -> 'tuple: config, var_meta_dict':
        """Use the value from *var2meta* if present, else draw one log-uniformly."""
        var2meta = {} if var2meta is None else var2meta
        if self.name in var2meta:
            value = var2meta[self.name]
        else:
            value = float(np.exp(np.random.uniform(np.log(self.low), np.log(self.high))))
        return value, {self.name: value}

class Space:
    """A search space parsed into an Ast tree plus a registry of its variables.

    Args:
        space: optional definition. Either a JSON string, or an already-loaded
            Python structure (dict / list / primitive). Defaults to an empty
            dict, which matches the previous no-argument behaviour.
    """

    def __init__(self, space=None):
        import json
        if isinstance(space, str):
            self.space_str = space  # json str
            self.space_raw = json.loads(space)  # after json load to python structure
        else:
            self.space_str = ''
            self.space_raw = {} if space is None else space
        global g_vars
        # parse_ast registers variables into the module-level g_vars; start
        # from a fresh dict so this Space only sees its own variables.
        g_vars = {}
        self.space_ast = parse_ast(self.space_raw)
        self.name2var = g_vars  # name->var_ast

    def sample(self, var2meta=None) -> 'tuple: config, var_meta_dict':
        """Sample a config; metas given in *var2meta* pin those variables."""
        # Pass a dict, never None: variable nodes test ``name in var2meta``.
        return self.space_ast.sample({} if var2meta is None else var2meta)

    def perturb(self, var2meta) -> 'tuple: config, var_meta_dict':
        """Perturb every meta in *var2meta* via its variable, then resample.

        Raises KeyError if *var2meta* names a variable not in this space.
        """
        new_var2meta = copy.deepcopy(var2meta)
        metas_to_perturb = var2meta  # filter
        for vname, vmeta in metas_to_perturb.items():
            new_var2meta[vname] = self.name2var[vname].perturb(vmeta)
        return self.space_ast.sample(new_var2meta)

    def grid(self):
        """Enumerate every config of a gridable space as ``(confs, metas)``."""
        assert self.space_ast.is_grid()
        confs, metas = self.space_ast.grid()
        return confs, metas

语法制导翻译 NodeVisitor

#!/usr/bin/env python
# -*- coding:utf8 -*-
"""
 Author: zhaopenghao
 Create Time: 2020/5/19 上午10:55
"""
import json

import numpy as np
from rudder_autosearch.space.space import Space
from rudder_autosearch.space.space import Ast

def iter_fields(node):
    """
    Yield a tuple of ``(fieldname, value)`` for each field of *node*.

    Mirrors ``ast.iter_fields``: when *node* declares ``_fields``, only those
    names are yielded, and fields missing from *node* are skipped.

    Nodes that do not declare ``_fields`` (such as the Ast classes in the
    space module) fall back to their instance attributes. Previously they
    raised AttributeError, which made ``NodeVisitor.generic_visit`` unusable
    on them. Objects without ``__dict__`` yield nothing.
    """
    fields = getattr(node, '_fields', None)
    if fields is None:
        # Snapshot the attributes so visiting cannot mutate the dict mid-iteration.
        yield from list(getattr(node, '__dict__', {}).items())
        return
    for field in fields:
        try:
            yield field, getattr(node, field)
        except AttributeError:
            pass

class NodeVisitor(object):
    """refers to standard library `ast.NodeVisitor`

    Subclasses define ``visit_<ClassName>`` methods. Nodes without one go to
    generic_visit(), which visits every Ast child found in the node's fields,
    either directly or inside a list/dict field.

    Output text accumulates in the class-level ``result`` string via write();
    clear() resets it.
    """

    def visit(self, node):
        """Visit a node."""
        method = 'visit_' + node.__class__.__name__
        visitor = getattr(self, method, self.generic_visit)
        return visitor(node)

    def generic_visit(self, node):
        """Called if no explicit visitor function exists for a node."""
        for _field, value in iter_fields(node):
            if isinstance(value, list):
                children = value
            elif isinstance(value, dict):
                children = value.values()
            else:
                children = (value,)
            for child in children:
                if isinstance(child, Ast):
                    self.visit(child)

    def space2hpo_space(self, space):
        """Translate *space* into a target library's space; subclasses implement."""
        raise NotImplementedError

    @classmethod
    def write(cls, text):
        """Append *text* to the class-level ``result`` buffer."""
        # Default to '' so write() works even if clear() was never called
        # (previously this raised AttributeError).
        cls.result = getattr(cls, 'result', '') + text

    @classmethod
    def clear(cls):
        """Reset the class-level ``result`` buffer to an empty string."""
        cls.result = ""

def visit_DictAst_public_one(node_visitor, node):
    """Emit a DictAst as a ``{key: value,...}`` literal through *node_visitor*.

    Keys are written with ``json.dumps(str(key))``, so quotes and backslashes
    are escaped. The key literal is then valid in both JSON and Python; the
    old ``'"{}"'`` formatting broke on such keys. Each value is emitted by
    visiting the child node. Every entry is followed by ',' (valid Python).
    """
    node_visitor.write('{')
    for key, child in node.dict.items():
        node_visitor.write('{}: '.format(json.dumps(str(key), ensure_ascii=False)))
        node_visitor.visit(child)
        node_visitor.write(',')
    node_visitor.write('}')

def visit_ListAst_public_one(node_visitor, node):
    """Emit a ListAst as ``[item,item,...]``; every element is followed by ','."""
    emit = node_visitor.write
    emit('[')
    for element in node.list:
        node_visitor.visit(element)
        emit(',')
    emit(']')

def visit_PrimitiveAst_public_one(node_visitor, node):
    """Emit a constant as a literal through *node_visitor*.

    Strings go through ``json.dumps``, so embedded quotes and backslashes are
    escaped. The result is a valid literal for both JSON and Python ``eval``;
    the old ``'"{}"'`` formatting broke on such strings. Other values are
    written with ``str.format``.
    """
    value = node.value
    if isinstance(value, str):
        node_visitor.write(json.dumps(value, ensure_ascii=False))
    else:
        node_visitor.write('{}'.format(value))

class HpoVisitor(NodeVisitor):
    """HyperOpt space code generator, ref to `astunparse.unparser.Unparser`

    Visiting a Space's Ast writes, into ``result``, the source of a Python
    expression built from ``hp.*`` / ``scope.*`` calls. space2hpo_space()
    evaluates that source into a hyperopt search space.
    """
    result = ""

    def __init__(self):
        self.clear()

    @staticmethod
    def _label(node):
        """Return *node*'s name as an escaped, double-quoted string literal."""
        return json.dumps(str(node.name), ensure_ascii=False)

    def space2hpo_space(self, space):
        """Translate *space* (a Space) into a hyperopt search space object."""
        from hyperopt import hp
        from hyperopt.pyll import scope
        self.clear()
        self.visit(space.space_ast)
        # SECURITY NOTE(review): the generated source is executed with eval().
        # Labels and string constants are escaped, and the namespace below
        # holds only the names the generated code uses. eval is still not a
        # sandbox, so do not translate untrusted space definitions.
        hps = eval(self.result, {'__builtins__': {}, 'hp': hp, 'scope': scope})
        return hps

    def visit_DictAst(self, node):
        """visit_DictAst"""
        visit_DictAst_public_one(self, node)

    def visit_ListAst(self, node):
        """visit_ListAst"""
        visit_ListAst_public_one(self, node)

    def visit_PrimitiveAst(self, node):
        """visit_PrimitiveAst"""
        visit_PrimitiveAst_public_one(self, node)

    def visit_ChoiceAst(self, node):
        """Emit ``hp.choice(label, [option, ...])``."""
        self.write('hp.choice({}, ['.format(self._label(node)))
        for v in node.values:
            self.visit(v)
            self.write(',')
        self.write('])')

    def visit_RandintAst(self, node):
        """Emit ``hp.randint(label, low, high + 1)``."""
        # NOTE(review): ``high + 1`` treats ``high`` as inclusive, whereas
        # RandintAst.sample/grid in the space module draw from [low, high)
        # (np.random.randint / range). Confirm which bound is intended.
        self.write('hp.randint({}, {}, {})'.format(self._label(node), node.low, node.high + 1))

    def visit_UniformAst(self, node):
        """Emit ``hp.uniform(label, low, high)``."""
        self.write('hp.uniform({}, {}, {})'.format(self._label(node), node.low, node.high))

    def visit_LoguniformAst(self, node):
        """Emit ``hp.loguniform(label, ln(low), ln(high))``.

        hp.loguniform takes its bounds in natural-log space. A log-uniform
        distribution does not depend on the log base, so ``node.base`` is not
        needed here.
        """
        exp_low = float(np.log(node.low))
        exp_high = float(np.log(node.high))
        self.write('hp.loguniform({}, {}, {})'.format(self._label(node), exp_low, exp_high))

    def visit_QuniformAst(self, node):
        """Emit ``low + hp.quniform(label, 0, high - low, step)``, wrapped in scope.int for int dtype.

        NOTE(review): reads ``low``, ``high``, ``step`` and ``dtype`` from the
        node; no QuniformAst is defined alongside the other Ast classes here.
        """
        low, high, step, dtype = node.low, node.high, node.step, node.dtype
        new_low = 0
        new_high = high - low
        # hp.quniform(label, low, high, q) returns round(uniform(low, high) / q) * q
        # (http://hyperopt.github.io/hyperopt/getting-started/search_spaces/),
        # so the grid is anchored at 0: sample on [0, high - low] and add ``low``
        # back, which keeps values on low + k * step.
        value = '{} + hp.quniform({}, {}, {}, {})'.format(low, self._label(node), new_low, new_high, step)
        if dtype == "int":
            value = 'scope.int({})'.format(value)
        self.write(value)