Over the last year, ChemBERTa has evolved from a single Colab notebook training a small transformer model on ZINC 100k into a larger project encompassing masked language modelling of large chemical libraries, molecular property prediction (both classification and regression), attention visualization, and other utilities.
However, there is still a long way to go in figuring out how to better embed this infrastructure within the wider DeepChem ecosystem. Doing so would be incredibly valuable for establishing a consistent API design for incorporating NLP models and tasks across a variety of disciplines (molecular property prediction, retrosynthesis, protein structure prediction). I hope this note outlines ChemBERTa's current layout and suggests how to incorporate it within the DeepChem library, to support a wider community and make transformer models in molecular ML easier to use.
Here is the rough sequence of next steps:
The ChemBERTa code is written in PyTorch (like huggingface/transformers), so all model implementations would go under deepchem/deepchem/models/torch_models. ChemBERTa uses a base RoBERTa model, so all models subclass RobertaPreTrainedModel from hf/transformers. This is an abstract class that handles weight initialization and provides a simple interface for downloading and loading pre-trained models (through the HF model hub).
All of the models follow the same basic layout:
class RobertaForMiscTask(RobertaPreTrainedModel):
    def __init__(self, config):
        # initialize RobertaModel plus a task head
        # (RobertaClassificationHead, RobertaRegressionHead, etc.)
        ...

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        # compute the task loss (MSE or cross-entropy) from the labels,
        # then return the base class for misc. model outputs
        return RegressionOutput/MaskedLMOutput/SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
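For concreteness, here is a minimal sketch of how one of these task heads can be pulled off the HF model hub and called. The checkpoint name is just an illustrative ChemBERTa checkpoint; the classification head on top of the pre-trained encoder is freshly initialized and would still need fine-tuning.

    from transformers import AutoTokenizer, RobertaForSequenceClassification

    # Illustrative ChemBERTa checkpoint from the HF model hub
    tokenizer = AutoTokenizer.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
    model = RobertaForSequenceClassification.from_pretrained(
        "seyonec/ChemBERTa-zinc-base-v1", num_labels=2)

    # Tokenize a toy SMILES string (ethanol) and run a forward pass
    inputs = tokenizer("CCO", return_tensors="pt")
    outputs = model(**inputs)
    print(outputs.logits.shape)  # torch.Size([1, 2])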
To bring these model implementations into DeepChem, we will have to adapt them to use multiple inheritance from TorchModel (as all other PyTorch models in DC do) and RobertaPreTrainedModel:
class RobertaForMiscTask(RobertaPreTrainedModel, TorchModel):
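A very rough sketch of what this could look like for a regression model is below. It assumes TorchModel's constructor accepts the wrapped nn.Module and a DeepChem loss (as it does today); the class name and the exact wiring are illustrative, not a final design, and the interaction between the two parent constructors would need care in practice.

    import torch.nn as nn
    from transformers import RobertaModel
    from transformers.models.roberta.modeling_roberta import RobertaPreTrainedModel
    from deepchem.models import TorchModel
    from deepchem.models.losses import L2Loss


    class RobertaForRegressionDC(RobertaPreTrainedModel, TorchModel):
        # Hypothetical DeepChem-wrapped regression model; names are illustrative.

        def __init__(self, config, **kwargs):
            # HF side: base encoder plus a small regression head
            RobertaPreTrainedModel.__init__(self, config)
            self.roberta = RobertaModel(config)
            self.regressor = nn.Linear(config.hidden_size, 1)
            self.init_weights()
            # DeepChem side: register this module and its loss with TorchModel,
            # so fit()/predict() work on DeepChem datasets
            TorchModel.__init__(self, model=self, loss=L2Loss(), **kwargs)

        def forward(self, input_ids=None, attention_mask=None):
            outputs = self.roberta(input_ids, attention_mask=attention_mask)
            # regress from the <s> (CLS) token representation to a scalar
            pooled = outputs.last_hidden_state[:, 0]
            return self.regressor(pooled)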
The goal is to add all three model implementations (RobertaForSequenceClassification, RobertaRegression, RobertaForMaskedLM) into dc.models.torch_models, alongside their equivalent base classes for outputs.
The next goal is to add a SmilesFeaturizer class that inherits from MolecularFeaturizer and RobertaTokenizerFast, to complement the existing SmilesTokenizer (which uses regex-based tokenization).
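As a starting point, a rough sketch of what such a featurizer could look like is below. For brevity it wraps the tokenizer via composition rather than the multiple inheritance proposed above, and the class name, constructor arguments, and checkpoint are all illustrative.

    import numpy as np
    from rdkit import Chem
    from transformers import RobertaTokenizerFast
    from deepchem.feat import MolecularFeaturizer


    class SmilesFeaturizer(MolecularFeaturizer):
        # Hypothetical featurizer that turns SMILES strings into token-id arrays
        # using a pretrained BPE tokenizer from the HF hub.

        def __init__(self, pretrained_name="seyonec/ChemBERTa-zinc-base-v1",
                     max_length=512):
            self.tokenizer = RobertaTokenizerFast.from_pretrained(pretrained_name)
            self.max_length = max_length

        def _featurize(self, datapoint, **kwargs):
            # MolecularFeaturizer passes an RDKit Mol; convert it back to SMILES
            # and tokenize into a fixed-length array of input ids.
            smiles = Chem.MolToSmiles(datapoint)
            encoding = self.tokenizer(smiles, truncation=True,
                                      padding="max_length",
                                      max_length=self.max_length)
            return np.asarray(encoding["input_ids"])

A featurizer along these lines would let the standard DeepChem data loaders hand tokenized SMILES straight to the RoBERTa models described above.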