import pandas as pd
from causalnlp.core.causalbert import CausalBertModel
CausalBert
CausalBertModel
CausalBertModel (g_weight=0.0, Q_weight=0.1, mlm_weight=1.0, batch_size=32, max_length=128, model_name='distilbert-base-uncased')
CausalBertModel is a wrapper for CausalBert
CausalBert
CausalBert (config)
CausalBert is essentially an S-Learner that uses a DistilBert sequence classification model as the base learner.
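To make the S-Learner framing concrete, the sketch below shows how a single outcome model that takes the treatment as an input yields an ATE estimate. It is an illustration only: q_model and its predict method are hypothetical stand-ins for the DistilBert-based learner, not part of this API.

import numpy as np

def s_learner_ate(q_model, texts, confounds):
    # An S-Learner fits ONE outcome model Q(text, confounds, T) with the
    # treatment T included as an ordinary input feature. The ATE is the
    # average difference between predictions under T=1 and T=0 for the
    # same units. `q_model` is a hypothetical stand-in, not this API.
    q1 = q_model.predict(texts, confounds, treatment=1)  # predicted Y if treated
    q0 = q_model.predict(texts, confounds, treatment=0)  # predicted Y if untreated
    return np.mean(q1 - q0)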
make_bow_vector
make_bow_vector (ids, vocab_size, use_counts=False)
Make a sparse BOW vector from a tensor of dense ids.
Args:
    ids: torch.LongTensor [batch, features]. Dense tensor of ids.
    vocab_size: vocab size for this tensor.
    use_counts: if true, the outgoing BOW vector will contain feature counts. If false, will contain binary indicators.
Returns:
    The sparse bag-of-words representation of ids.
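As an illustration of the docstring above, a sparse bag-of-words matrix can be assembled with torch.sparse_coo_tensor. This is a sketch of how such a function could be written, not necessarily CausalBert's exact implementation:

import torch

def make_bow_vector_sketch(ids, vocab_size, use_counts=False):
    # Row i of `ids` holds the token ids for example i; the result is a
    # sparse [batch, vocab_size] matrix with one nonzero per observed token.
    batch_size, n_feats = ids.shape
    rows = torch.arange(batch_size).repeat_interleave(n_feats)
    cols = ids.reshape(-1)
    vals = torch.ones(batch_size * n_feats)
    # coalesce() sums values at duplicate (row, col) positions,
    # turning the ones into per-token counts.
    bow = torch.sparse_coo_tensor(
        torch.stack([rows, cols]), vals, (batch_size, vocab_size)
    ).coalesce()
    if not use_counts:
        # Clamp counts to 1 to get binary indicators instead.
        bow = torch.sparse_coo_tensor(
            bow.indices(), bow.values().clamp(max=1.0), bow.shape
        )
    return bow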
gelu
gelu (x)
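gelu is the Gaussian Error Linear Unit activation used inside transformer layers. A reference implementation of the exact, erf-based form follows; the module's own definition may instead use a tanh approximation:

import math
import torch

def gelu_reference(x):
    # GELU(x) = x * Phi(x), where Phi is the standard normal CDF.
    return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))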
platt_scale
platt_scale (outcome, probs)
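Judging from its name and arguments, platt_scale calibrates predicted probabilities against observed outcomes via Platt scaling, i.e. a one-dimensional logistic regression. The sketch below illustrates that idea with scikit-learn as a stand-in; it is an assumption about the behavior, not the module's code:

import numpy as np
from sklearn.linear_model import LogisticRegression

def platt_scale_sketch(outcome, probs):
    # Fit a 1-D logistic regression mapping raw probabilities to
    # calibrated ones (Platt, 1999); assumes a binary `outcome`.
    probs = np.asarray(probs).reshape(-1, 1)
    calibrator = LogisticRegression()
    calibrator.fit(probs, np.asarray(outcome))
    return calibrator.predict_proba(probs)[:, 1]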
CausalBertModel.train
CausalBertModel.train (texts, confounds, treatments, outcomes, learning_rate=2e-05, epochs=3)
Trains a CausalBert model
CausalBertModel.estimate_ate
CausalBertModel.estimate_ate (C, W, Y=None, platt_scaling=False)
Computes average treatment effect using the trained estimator
CausalBertModel.inference
CausalBertModel.inference (texts, confounds, outcome=None)
Perform inference using the trained model
Example
This implementation of CausalBert was adapted from Causal Effects of Linguistic Properties by Pryzant et al. CausalBert is essentially a kind of S-Learner that uses a DistilBert sequence classification model as the base learner.
df = pd.read_csv('sample_data/music_seed50.tsv', sep='\t', on_bad_lines='skip')
cb = CausalBertModel(batch_size=32, max_length=128)
cb.train(df['text'], df['C_true'], df['T_ac'], df['Y_sim'], epochs=1, learning_rate=2e-5)
print(cb.estimate_ate(df['C_true'], df['text']))
Some weights of CausalBert were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['Q_cls.1.0.bias', 'Q_cls.0.0.bias', 'g_cls.weight', 'Q_cls.1.0.weight', 'g_cls.bias', 'Q_cls.1.2.bias', 'Q_cls.0.2.weight', 'Q_cls.0.0.weight', 'Q_cls.0.2.bias', 'Q_cls.1.2.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
100%|██████████| 666/666 [02:12<00:00, 5.01it/s]
100%|██████████| 666/666 [00:27<00:00, 24.32it/s]
0.17478953341997637
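Once trained, the model can also produce per-example predictions through CausalBertModel.inference. A usage sketch based on the signature documented above (the shape of the return value is not shown here):

# Per-example inference on the same data.
preds = cb.inference(df['text'], df['C_true'])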
(Reduce the batch_size if you receive an out-of-memory error when running the code above.)
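For example, construct the model with a smaller batch (8 here is only an illustration; pick whatever fits your hardware):

# A smaller batch size lowers peak GPU memory at the cost of speed.
cb = CausalBertModel(batch_size=8, max_length=128)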