CausalBert API
Example
This implementation of CausalBert was adapted from "Causal Effects of Linguistic Properties" by Pryzant et al. CausalBert is essentially an S-Learner, i.e., a single outcome model that takes the treatment as one of its inputs, with a DistilBERT sequence-classification model as the base learner.
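To illustrate the S-Learner idea on its own, here is a toy sketch in plain Python. It is not CausalBert's actual code. GroupMeanRegressor (a hypothetical stand-in for the DistilBERT base learner that simply averages outcomes per exact feature tuple) and s_learner_ate are illustrative names. One model is fit with the treatment appended to the covariates. The average treatment effect (ATE) is then the mean gap between that model's predictions with the treatment set to 1 and set to 0. The sketch assumes a binary treatment and that every covariate stratum contains both treated and control units.

```python
from collections import defaultdict
from statistics import mean


class GroupMeanRegressor:
    """Hypothetical toy base learner: predicts the mean outcome observed
    for each exact feature tuple (a stand-in for DistilBERT)."""

    def fit(self, X, y):
        groups = defaultdict(list)
        for features, outcome in zip(X, y):
            groups[tuple(features)].append(outcome)
        self.means_ = {key: mean(vals) for key, vals in groups.items()}
        return self

    def predict(self, X):
        # Assumes every feature tuple was seen during fit (raises KeyError otherwise).
        return [self.means_[tuple(features)] for features in X]


def s_learner_ate(covariates, treatments, outcomes, base_learner):
    """Estimate the ATE with a single model that receives the treatment as a feature."""
    # S-Learner: append the (binary) treatment to each unit's covariates and fit ONE model.
    X = [list(c) + [t] for c, t in zip(covariates, treatments)]
    model = base_learner.fit(X, outcomes)
    # Predict every unit's outcome under treatment (t=1) and under control (t=0).
    y1 = model.predict([list(c) + [1] for c in covariates])
    y0 = model.predict([list(c) + [0] for c in covariates])
    # ATE = average of the per-unit predicted differences.
    return mean(a - b for a, b in zip(y1, y0))
```

For example, with one binary covariate and outcomes generated as Y = 1 + 2T + 3C, s_learner_ate([[0], [0], [1], [1]], [0, 1, 0, 1], [1, 3, 4, 6], GroupMeanRegressor()) recovers an ATE of 2. CausalBert replaces the toy regressor with a DistilBERT model so that the text itself can inform the outcome predictions.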
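import pandas as pd
from causalnlp.core.causalbert import CausalBertModel

# Load the sample data, skipping malformed rows (on_bad_lines='skip' replaces
# error_bad_lines=False, which was removed in pandas 2.0).
df = pd.read_csv('sample_data/music_seed50.tsv', sep='\t', on_bad_lines='skip')

# Positional arguments: texts, confounders (C_true), treatments (T_ac), outcomes (Y_sim).
cb = CausalBertModel(batch_size=32, max_length=128)
cb.train(df['text'], df['C_true'], df['T_ac'], df['Y_sim'], epochs=1, learning_rate=2e-5)
# Estimate the average treatment effect (ATE).
print(cb.estimate_ate(df['C_true'], df['text']))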
(Reduce batch_size if you receive an out-of-memory (OOM) error when running the code above.)
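If you would rather not pick batch_size by hand, one option is to retry training with a halved batch size after each OOM error. The helper below is a hypothetical sketch, not part of CausalNLP: train_with_batch_backoff, make_model, and fit are illustrative names. It assumes that PyTorch reports a CUDA out-of-memory condition as a RuntimeError (or a subclass of it) whose message contains "out of memory", and it re-raises any other error unchanged.

```python
def train_with_batch_backoff(make_model, fit, batch_size=32, min_batch_size=1):
    """Hypothetical helper: build and train a model, halving batch_size after each OOM error.

    make_model(batch_size) -> model, and fit(model) -> None are caller-supplied callables.
    Returns (trained_model, batch_size_that_worked).
    """
    while batch_size >= min_batch_size:
        model = make_model(batch_size)
        try:
            fit(model)
            return model, batch_size
        except RuntimeError as err:
            # Assumption: OOM errors mention "out of memory"; anything else is a real bug.
            if "out of memory" not in str(err).lower():
                raise
            del model  # drop the reference so its memory can be reclaimed
            batch_size //= 2
    raise RuntimeError(f"no batch size >= {min_batch_size} fit in memory")
```

With CausalNLP, make_model could be lambda bs: CausalBertModel(batch_size=bs, max_length=128) and fit could be lambda m: m.train(df['text'], df['C_true'], df['T_ac'], df['Y_sim'], epochs=1, learning_rate=2e-5). Halving is a simple, conservative choice; lowering max_length also reduces memory use per batch.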