CausalBert

CausalBert API

source

CausalBertModel

 CausalBertModel (g_weight=0.0, Q_weight=0.1, mlm_weight=1.0, batch_size=32,
                  max_length=128, model_name='distilbert-base-uncased')

CausalBertModel is a wrapper for CausalBert that exposes the train, estimate_ate, and inference methods documented below.


source

CausalBert

 CausalBert (config)

CausalBert is essentially an S-Learner that uses a DistilBert sequence classification model as the base learner.
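
A conceptual sketch of the S-Learner idea (the names Q0, Q1, and the values below are hypothetical, not part of the API): the outcome model is evaluated twice per document, once with the treatment set to 0 and once with it set to 1, and the ATE is the average difference.

import numpy as np

# Hypothetical predicted outcomes for three documents from the outcome model Q:
Q0 = np.array([0.2, 0.5, 0.1])  # Q(T=0, text): prediction with treatment set to 0
Q1 = np.array([0.4, 0.7, 0.3])  # Q(T=1, text): prediction with treatment set to 1
ate = (Q1 - Q0).mean()          # average treatment effect = 0.2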


source

make_bow_vector

 make_bow_vector (ids, vocab_size, use_counts=False)

Make a sparse BOW vector from a tensor of dense ids.

Args:
    ids: torch.LongTensor [batch, features]. Dense tensor of ids.
    vocab_size: vocab size for this tensor.
    use_counts: if true, the outgoing BOW vector will contain feature counts.
        If false, will contain binary indicators.

Returns:
    The sparse bag-of-words representation of ids.
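
A usage sketch with a toy vocabulary of size 5 (the ids below are made up), assuming make_bow_vector is importable from the same module as CausalBertModel:

import torch
from causalnlp.core.causalbert import make_bow_vector

ids = torch.LongTensor([[1, 3, 3, 4], [2, 2, 4, 1]])        # [batch=2, features=4]
bow = make_bow_vector(ids, vocab_size=5)                    # binary indicators
bow_counts = make_bow_vector(ids, vocab_size=5, use_counts=True)  # feature counts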


source

gelu

 gelu (x)
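
gelu has no docstring in the source; it is the Gaussian Error Linear Unit activation. A minimal sketch of the exact erf-based form (the source may instead use the tanh approximation):

import math
import torch

def gelu(x):
    # Exact GELU: x * Phi(x), where Phi is the standard normal CDF
    return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))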

source

platt_scale

 platt_scale (outcome, probs)
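
platt_scale is likewise undocumented. Platt scaling calibrates raw model probabilities by fitting a logistic regression from the model's scores to the observed outcomes. The sketch below is an assumption about the approach, not the exact source implementation:

import numpy as np
from scipy.special import logit
from sklearn.linear_model import LogisticRegression

def platt_scale(outcome, probs):
    # Classic Platt scaling: regress the observed outcomes on the
    # model's scores (here, logits of its probabilities) and return
    # calibrated probabilities.
    scores = logit(np.asarray(probs)).reshape(-1, 1)
    calibrator = LogisticRegression()
    calibrator.fit(scores, outcome)
    return calibrator.predict_proba(scores)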

source

CausalBertModel.train

 CausalBertModel.train (texts, confounds, treatments, outcomes,
                        learning_rate=2e-05, epochs=3)

Trains a CausalBert model on the supplied texts, confounds, treatments, and outcomes.


source

CausalBertModel.estimate_ate

 CausalBertModel.estimate_ate (C, W, Y=None, platt_scaling=False)

Computes the average treatment effect (ATE) using the trained estimator, where C is the confound(s) and W is the raw text (see the example below).


source

CausalBertModel.inference

 CausalBertModel.inference (texts, confounds, outcome=None)

Performs inference using the trained model.
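
A usage sketch, assuming cb and df from the example below; texts and confounds follow the same format as in train:

preds = cb.inference(df['text'], df['C_true'])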

Example

This implementation of CausalBert was adapted from Causal Effects of Linguistic Properties by Pryzant et al. CausalBert is essentially an S-Learner that uses a DistilBert sequence classification model as the base learner.

import pandas as pd
from causalnlp.core.causalbert import CausalBertModel

df = pd.read_csv('sample_data/music_seed50.tsv', sep='\t', on_bad_lines='skip')
cb = CausalBertModel(batch_size=32, max_length=128)
# columns: text, confound (C_true), treatment (T_ac), outcome (Y_sim)
cb.train(df['text'], df['C_true'], df['T_ac'], df['Y_sim'], epochs=1, learning_rate=2e-5)
print(cb.estimate_ate(df['C_true'], df['text']))
Some weights of CausalBert were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['Q_cls.1.0.bias', 'Q_cls.0.0.bias', 'g_cls.weight', 'Q_cls.1.0.weight', 'g_cls.bias', 'Q_cls.1.2.bias', 'Q_cls.0.2.weight', 'Q_cls.0.0.weight', 'Q_cls.0.2.bias', 'Q_cls.1.2.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
100%|██████████| 666/666 [02:12<00:00,  5.01it/s]
100%|██████████| 666/666 [00:27<00:00, 24.32it/s]
0.17478953341997637

(Reduce batch_size if you encounter an out-of-memory error when running the code above.)