Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 41 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,31 @@ python main.py \
--train
```

### Train with Proxy MSE

The Proxy module learns to approximate the white-box probability features from
the encoder hidden states. During training, `--proxy-prob` controls how often the
model consumes proxy-generated features, while `--mse-weight` supervises the
proxy features against the original white-box probability features.

```bash
python main.py \
--cuda \
--seed 2024 \
--exp-name moe+logits+cl+proxy_mse_arxiv-lora_5e-4 \
--train-path /feature/arxiv_new/lora/train.jsonl \
--val-path /feature//arxiv_new/lora/val.jsonl \
--test-path /feature/arxiv_new/lora/test_ood.jsonl \
--batch-size 64 \
--lr 5e-4 \
--is-cl \
--proxy-prob 0.5 \
--mse-weight 1.0 \
--use-curriculum \
--proxy-warmup-epochs 10 \
--train
```

### Evaluation

```bash
Expand All @@ -88,6 +113,22 @@ python main.py \
--test
```

To evaluate with proxy-generated features instead of white-box probability
features, add `--use-proxy`:

```bash
python main.py \
--cuda \
--seed 2024 \
--exp-name moe+logits+cl+proxy_mse_arxiv-lora_5e-4 \
--test-path /feature/arxiv_new/lora/test_ood.jsonl \
--batch-size 64 \
--lr 5e-4 \
--is-binary \
--use-proxy \
--test
```



## License
Expand Down
25 changes: 24 additions & 1 deletion main.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,11 @@ def parse_args():
parser.add_argument('--train', action='store_true')
parser.add_argument('--is-binary', action='store_true', help='True indicate binary classification,False indicate multi-classification')
parser.add_argument('--is-cl', action='store_true', help='if use contrastive learning')
parser.add_argument('--use-proxy', action='store_true', help='use the proxy module to replace white-box probability features at inference time')
parser.add_argument('--proxy-prob', type=float, default=0.0, help='probability of using proxy-generated probability features during training')
parser.add_argument('--proxy-warmup-epochs', type=int, default=0, help='number of epochs used to warm up proxy probability and MSE weight')
parser.add_argument('--use-curriculum', action='store_true', help='linearly increase proxy probability and MSE weight during warmup')
parser.add_argument('--mse-weight', type=float, default=0.0, help='weight for MSE loss between proxy features and white-box probability features')
return parser.parse_args()

def set_seed(seed):
Expand Down Expand Up @@ -62,7 +67,25 @@ def main(args):
train_dataloader = get_dataloader(args.train_path, args.pretrain_model, args.batch_size, args.max_len, label2id, shuffle=True) if not args.test else None
val_dataloader = get_dataloader(args.val_path, args.pretrain_model, args.batch_size, args.max_len, label2id, shuffle=False) if not args.test else None
test_dataloader = get_dataloader(args.test_path, args.pretrain_model, args.batch_size, args.max_len, label2id, shuffle=False)
trainer = Trainer(device, args.pretrain_model, train_dataloader, val_dataloader, test_dataloader, args.epoch, args.lr, args.early_stop, model_save_path, args.n_family, args.is_cl, args.is_binary)
trainer = Trainer(
device,
args.pretrain_model,
train_dataloader,
val_dataloader,
test_dataloader,
args.epoch,
args.lr,
args.early_stop,
model_save_path,
args.n_family,
args.is_cl,
args.is_binary,
proxy_prob=args.proxy_prob,
proxy_warmup_epochs=args.proxy_warmup_epochs,
mse_weight=args.mse_weight,
use_proxy_inference=args.use_proxy,
use_curriculum=args.use_curriculum,
)

if not args.test:
trainer.train()
Expand Down
64 changes: 54 additions & 10 deletions model.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,19 @@
from scl_loss import SupConLoss
from typing import List, Tuple

class ProxyModule(nn.Module):
def __init__(self, emb_dim, n_feat):
super(ProxyModule, self).__init__()
self.net = nn.Sequential(
nn.Linear(emb_dim, 256),
nn.ReLU(),
nn.Linear(256, n_feat)
)

def forward(self, feature):
out = self.net(feature)
return out.transpose(1, 2)

class MLP(nn.Module):
def __init__(self, input_dim, hidden_dims, output_dim, dropout):
super(MLP, self).__init__()
Expand Down Expand Up @@ -87,7 +100,8 @@ def __init__(self, n_family=4, emb_dim=768, hidden_dims=[256], dropout=0.2, feat
super(Model, self).__init__()
self.n_family = n_family

self.n_feat = 3
self.n_feat = n_family - 1
self.proxy_module = ProxyModule(emb_dim, self.n_feat)
feature_enc_layers = [(64, 5, 1)] + [(128, 3, 1)] * 3 + [(64, 3, 1)]
self.conv = ConvFeatureExtractionModel(
conv_layers=feature_enc_layers,
Expand Down Expand Up @@ -128,9 +142,17 @@ def conv_feat_extract(self, x):
out = out.transpose(1, 2)
return out

def forward(self, prob_feature, feature, mask):
def forward(self, prob_feature, feature, mask, use_proxy_prob=0.0):
pred_prob_feature = self.proxy_module(feature)

if self.training and use_proxy_prob > 0.0:
if torch.rand((), device=feature.device).item() < use_proxy_prob:
prob_feature = pred_prob_feature
elif not self.training and use_proxy_prob >= 1.0:
prob_feature = pred_prob_feature

prob_feature = torch.cat([self.conv_feat_extract(prob_feature[:, i:i+1, :]) for i in range(self.n_feat)], dim=2) # (batch_size, seq_len, embedding_size)
prob_feature = prob_feature + self.position_encoding.cuda()
prob_feature = prob_feature + self.position_encoding.to(prob_feature.device)
prob_feature = self.norm(prob_feature)
prob_feature = self.encoder(prob_feature)
prob_feature = self.dropout(prob_feature) # (bs, seq_len, embedding_size)
Expand All @@ -142,19 +164,26 @@ def forward(self, prob_feature, feature, mask):

shared_feature = sum([self.expert[i](prob_feature) * gate[:, i].unsqueeze(1) for i in range(self.n_family)])
pred_binary = self.binary_classifier(shared_feature)
pred_binary = torch.sigmoid(pred_binary).squeeze()
pred_binary = torch.sigmoid(pred_binary).squeeze(-1)

return pred_binary, pred_family, family_feature
return pred_binary, pred_family, family_feature, pred_prob_feature

class Trainer:
def __init__(self, device, pretrain_model, train_dataloader, val_dataloader, test_dataloader, epoch, lr, early_stop, model_save_path, n_family, is_cl, is_binary):
def __init__(self, device, pretrain_model, train_dataloader, val_dataloader, test_dataloader, epoch, lr, early_stop, model_save_path, n_family, is_cl, is_binary, proxy_prob=0.0, proxy_warmup_epochs=0, mse_weight=0.0, use_proxy_inference=False, use_curriculum=False):
self.device = device
self.epoch = epoch
self.train_dataloader = train_dataloader
self.val_dataloader = val_dataloader
self.test_dataloader = test_dataloader
self.early_stop = early_stop
self.n_family = n_family
self.proxy_prob = proxy_prob
self.proxy_warmup_epochs = proxy_warmup_epochs
self.mse_weight = mse_weight
self.use_proxy_inference = use_proxy_inference
self.use_curriculum = use_curriculum
self.current_proxy_prob = 0.0
self.current_mse_weight = 0.0
self.pretrain = RobertaModel.from_pretrained(pretrain_model).to(device)
self.model_save_path = model_save_path
self.model = Model(n_family=n_family).to(device)
Expand All @@ -170,14 +199,18 @@ def get_loss(self, batch):
label_family = batch['label_family'].to(self.device)
label_binary = batch['label_binary'].to(self.device)

pred_binary, pred_family, family_feature = self.model(ll_tokens_list, feature, attention_mask)
pred_binary, pred_family, family_feature, pred_prob_feature = self.model(ll_tokens_list, feature, attention_mask, use_proxy_prob=self.current_proxy_prob)
if self.is_clLoss:
loss = nn.BCELoss()(pred_binary, label_binary.float()) \
+ nn.CrossEntropyLoss()(pred_family, label_family) \
+ SupConLoss(temperature=0.1)(family_feature.unsqueeze(dim=-1), label_family)
else:
loss = nn.BCELoss()(pred_binary, label_binary.float()) \
+ nn.CrossEntropyLoss()(pred_family, label_family)
if self.current_mse_weight > 0.0:
real_targets = ll_tokens_list[:, :self.model.n_feat, :]
loss_mse = nn.MSELoss()(pred_prob_feature, real_targets)
loss = loss + self.current_mse_weight * loss_mse
return loss

def get_output(self, batch):
Expand All @@ -187,23 +220,34 @@ def get_output(self, batch):
attention_mask = batch['attention_mask'].to(self.device)
feature = self.pretrain(input_ids, attention_mask).last_hidden_state.detach()
with torch.no_grad():
output, _, _ = self.model(ll_tokens_list, feature, attention_mask)
output, _, _, _ = self.model(ll_tokens_list, feature, attention_mask, use_proxy_prob=1.0 if self.use_proxy_inference else 0.0)
return output
else:
ll_tokens_list = batch['ll_tokens_list'].to(self.device)
input_ids = batch['input_ids'].to(self.device)
attention_mask = batch['attention_mask'].to(self.device)
feature = self.pretrain(input_ids, attention_mask).last_hidden_state.detach()
with torch.no_grad():
output, pred_family, _ = self.model(ll_tokens_list, feature, attention_mask)
output, pred_family, _, _ = self.model(ll_tokens_list, feature, attention_mask, use_proxy_prob=1.0 if self.use_proxy_inference else 0.0)
return output, pred_family



def train(self):
recorder = Recorder(self.early_stop)
for epoch in range(self.epoch):
print('----epoch %d----' % (epoch+1))
if self.use_curriculum and self.proxy_warmup_epochs > 0:
warmup_ratio = min(1.0, (epoch + 1) / self.proxy_warmup_epochs)
self.current_proxy_prob = self.proxy_prob * warmup_ratio
self.current_mse_weight = self.mse_weight * warmup_ratio
else:
self.current_proxy_prob = self.proxy_prob
self.current_mse_weight = self.mse_weight

if self.proxy_prob > 0.0 or self.mse_weight > 0.0:
print('----epoch %d (proxy_prob: %.2f, mse_weight: %.2f)----' % (epoch+1, self.current_proxy_prob, self.current_mse_weight))
else:
print('----epoch %d----' % (epoch+1))
self.model.train()
avg_loss = Averager()
for i, batch in enumerate(tqdm(self.train_dataloader)):
Expand Down