CausalBert API

platt_scale[source]

platt_scale(outcome, probs)

Calibrates predicted probabilities (probs) against the observed labels (outcome) using Platt scaling.
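Platt scaling fits a one-dimensional logistic regression that maps raw scores to calibrated probabilities. A minimal sketch of the technique using scikit-learn, assuming binary labels — this illustrates the idea, not necessarily this library's exact implementation:

import numpy as np
from sklearn.linear_model import LogisticRegression

def platt_scale_sketch(outcome, probs):
    # Fit a logistic map from the uncalibrated scores to the observed labels...
    scores = np.asarray(probs).reshape(-1, 1)
    calibrator = LogisticRegression()
    calibrator.fit(scores, np.asarray(outcome))
    # ...and return the calibrated probability of the positive class.
    return calibrator.predict_proba(scores)[:, 1]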

gelu[source]

gelu(x)

The GELU (Gaussian Error Linear Unit) activation function.
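For reference, the standard erf-based GELU formula, which a helper like this typically computes (a sketch, not necessarily this library's exact code):

import math
import torch

def gelu_sketch(x):
    # GELU(x) = x * Phi(x), where Phi is the standard normal CDF.
    return 0.5 * x * (1.0 + torch.erf(x / math.sqrt(2.0)))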

make_bow_vector[source]

make_bow_vector(ids, vocab_size, use_counts=False)

Make a sparse BOW vector from a tensor of dense ids.

Args:
  ids: torch.LongTensor [batch, features]. Dense tensor of ids.
  vocab_size: vocab size for this tensor.
  use_counts: if True, the outgoing BOW vector will contain feature counts; if False, it will contain binary indicators.

Returns: the sparse bag-of-words representation of ids.
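A hypothetical usage example, assuming the function is importable from causalnlp.core.causalbert alongside the classes below:

import torch
from causalnlp.core.causalbert import make_bow_vector

ids = torch.LongTensor([[3, 1, 1, 2]])  # one example with four token ids
bow = make_bow_vector(ids, vocab_size=5, use_counts=True)
# bow has shape [1, 5]; with use_counts=True, position 1 holds a count of 2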

class CausalBert[source]

CausalBert(config) :: DistilBertPreTrainedModel

CausalBert is essentially an S-Learner that uses a DistilBert sequence classification model as the base learner.
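In an S-Learner, a single outcome model receives the treatment as a feature, and the treatment effect is the average difference between its predictions with the treatment toggled on and off. An illustrative sketch with made-up predictions:

import numpy as np

Q0 = np.array([0.20, 0.40, 0.10])  # predicted outcomes with treatment set to 0
Q1 = np.array([0.55, 0.60, 0.30])  # predicted outcomes with treatment set to 1
ate = float((Q1 - Q0).mean())      # average treatment effect = 0.25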

class CausalBertModel[source]

CausalBertModel(g_weight=0.0, Q_weight=0.1, mlm_weight=1.0, batch_size=32, max_length=128, model_name='distilbert-base-uncased')

CausalBertModel is a wrapper around CausalBert that handles training, ATE estimation, and inference.

CausalBertModel.train[source]

CausalBertModel.train(texts, confounds, treatments, outcomes, learning_rate=2e-05, epochs=3)

Trains a CausalBert model on the given texts, confounds, treatments, and outcomes.

CausalBertModel.estimate_ate[source]

CausalBertModel.estimate_ate(C, W, Y=None, platt_scaling=False)

Computes the average treatment effect (ATE) using the trained estimator. C holds the confound(s) and W the raw texts, as in the example below. Setting platt_scaling=True (with the true outcomes supplied as Y) Platt-scales the predicted probabilities first.

CausalBertModel.inference[source]

CausalBertModel.inference(texts, confounds, outcome=None)

Performs inference with the trained model on new texts and confounds.

Example

This implementation of CausalBert was adapted from Causal Effects of Linguistic Properties by Pryzant et al. CausalBert is essentially an S-Learner that uses a DistilBert sequence classification model as the base learner.

import pandas as pd
df = pd.read_csv('sample_data/music_seed50.tsv', sep='\t', on_bad_lines='skip')  # error_bad_lines was removed in pandas 2.0
from causalnlp.core.causalbert import CausalBertModel
cb = CausalBertModel(batch_size=32, max_length=128)
cb.train(df['text'], df['C_true'], df['T_ac'], df['Y_sim'], epochs=1, learning_rate=2e-5)
print(cb.estimate_ate(df['C_true'], df['text']))
Some weights of CausalBert were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['Q_cls.1.0.bias', 'Q_cls.0.0.bias', 'g_cls.weight', 'Q_cls.1.0.weight', 'g_cls.bias', 'Q_cls.1.2.bias', 'Q_cls.0.2.weight', 'Q_cls.0.0.weight', 'Q_cls.0.2.bias', 'Q_cls.1.2.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
100%|██████████| 666/666 [02:12<00:00,  5.01it/s]
100%|██████████| 666/666 [00:27<00:00, 24.32it/s]
0.17478953341997637

(Reduce the batch_size if you receive an Out-Of-Memory error when running the code above.)
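Once trained, per-example predictions can be obtained with the inference method; a hypothetical call based on the signature above:

preds = cb.inference(df['text'], df['C_true'])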