-
Notifications
You must be signed in to change notification settings - Fork 207
Expand file tree
/
Copy pathxlnet_processor.py
More file actions
201 lines (179 loc) · 8.76 KB
/
xlnet_processor.py
File metadata and controls
201 lines (179 loc) · 8.76 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
import csv
import torch
import numpy as np
from ..common.tools import load_pickle
from ..common.tools import logger
from ..callback.progressbar import ProgressBar
from torch.utils.data import TensorDataset
from transformers import XLNetTokenizer
class InputExample:
    """One raw, untokenized example for sequence classification.

    Attributes:
        guid: unique identifier of the example.
        text_a: untokenized text of the first sequence; the only text needed
            for single-sequence tasks.
        text_b: optional untokenized text of the second sequence, only used
            for sequence-pair tasks.
        label: optional label(s); expected for train/dev examples and left
            unset for test examples.
    """

    def __init__(self, guid, text_a, text_b=None, label=None):
        # Attributes are assigned in the order guid, text_a, text_b, label.
        self.guid, self.text_a = guid, text_a
        self.text_b, self.label = text_b, label
class InputFeature:
    """Model-ready features for a single example.

    Attributes:
        input_ids: token ids, padded to a fixed length.
        input_mask: 1 for real tokens and 0 for padding positions.
        segment_ids: segment (token-type) id for every position.
        label_id: the example's label(s), copied unchanged from the example.
        input_len: number of real (non-padding) tokens.
    """

    def __init__(self, input_ids, input_mask, segment_ids, label_id, input_len):
        self.input_ids, self.input_mask = input_ids, input_mask
        self.segment_ids, self.label_id = segment_ids, label_id
        self.input_len = input_len
class XlnetProcessor(object):
    """Prepare multi-label text-classification data for an XLNet model.

    Typical flow: ``get_train``/``get_dev`` (raw lines) -> ``create_examples``
    -> ``create_features`` -> ``create_dataset``. Examples and features are
    cached on disk with ``torch.save`` and reloaded when the cache file exists.
    """

    def __init__(self, vocab_path, do_lower_case):
        """Build the XLNet tokenizer.

        Args:
            vocab_path: vocabulary file passed to ``XLNetTokenizer``.
            do_lower_case: whether the tokenizer lower-cases its input.
        """
        self.tokenizer = XLNetTokenizer(vocab_path, do_lower_case=do_lower_case)

    def get_train(self, data_file):
        """Return the raw training lines (see ``read_data``)."""
        return self.read_data(data_file)

    def get_dev(self, data_file):
        """Return the raw dev lines (see ``read_data``)."""
        return self.read_data(data_file)

    def get_test(self, lines):
        """Return the test ``lines`` unchanged."""
        return lines

    def get_labels(self):
        """Return the list of label names for this data set.

        NOTE(review): presumably index-aligned with each example's label vector.
        """
        return ["toxic", "severe_toxic", "obscene", "threat", "insult", "identity_hate"]

    @classmethod
    def read_data(cls, input_file, quotechar=None):
        """Return raw data lines.

        If the string form of ``input_file`` contains ``'pkl'``, it is treated
        as a path and loaded with ``load_pickle``. Otherwise ``input_file`` is
        assumed to already be the data and is returned unchanged.

        Args:
            input_file: pickle path, or the in-memory data itself.
            quotechar: unused; kept for interface compatibility.
        """
        if 'pkl' in str(input_file):
            # NOTE(review): load_pickle presumably unpickles the file; only use it
            # on trusted data, since unpickling can execute arbitrary code.
            return load_pickle(input_file)
        return input_file

    def truncate_seq_pair(self, tokens_a, tokens_b, max_length):
        """Truncate ``tokens_a``/``tokens_b`` in place to a total of ``max_length``.

        One token at a time is removed from the end of whichever list is
        currently longer (``tokens_b`` on ties). When one sequence is short,
        each of its tokens likely carries more information than a token of the
        long one, so this is preferred over truncating both by the same fraction.
        """
        while len(tokens_a) + len(tokens_b) > max_length:
            if len(tokens_a) > len(tokens_b):
                tokens_a.pop()
            else:
                tokens_b.pop()

    @staticmethod
    def _load_cache(cache_file):
        """Load an object that this class previously stored with ``torch.save``.

        The caches hold pickled ``InputExample``/``InputFeature`` objects.
        ``torch.load`` rejects those when ``weights_only=True``, which is the
        default since PyTorch 2.6, so full unpickling is requested explicitly.
        Full unpickling can execute arbitrary code: only point this at cache
        files produced by this processor.
        """
        import inspect
        if 'weights_only' in inspect.signature(torch.load).parameters:
            return torch.load(cache_file, weights_only=False)
        # Older PyTorch releases have no ``weights_only`` keyword and always unpickle fully.
        return torch.load(cache_file)

    def create_examples(self, lines, example_type, cached_examples_file):
        """Build ``InputExample`` objects from ``(text, labels)`` records.

        Args:
            lines: sized sequence of records. ``line[0]`` is the text and
                ``line[1]`` the labels, given either as a comma-separated string
                or as an iterable of numbers.
            example_type: prefix of every guid (``"<example_type>-<index>"``).
            cached_examples_file: Path-like object exposing ``.exists()``. When
                it exists, examples are loaded from it; otherwise they are built
                and saved there.

        Returns:
            list of ``InputExample`` with ``text_b=None`` and a list of float labels.
        """
        if cached_examples_file.exists():
            logger.info("Loading examples from cached file %s", cached_examples_file)
            return self._load_cache(cached_examples_file)

        pbar = ProgressBar(n_total=len(lines), desc='create examples')
        examples = []
        for i, line in enumerate(lines):
            guid = '%s-%d' % (example_type, i)
            text_a = line[0]
            label = line[1]
            if isinstance(label, str):
                label = [float(x) for x in label.split(",")]
            else:
                label = [float(x) for x in label]
            examples.append(InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
            pbar(step=i)
        logger.info("Saving examples into cached file %s", cached_examples_file)
        torch.save(examples, cached_examples_file)
        return examples

    def create_features(self, examples, max_seq_len, cached_features_file):
        """Tokenize ``InputExample`` objects into fixed-length XLNet features.

        Unlike BERT, the classification token is placed at the END and padding
        is added on the LEFT:

            tokens:      <pad> .. <pad>  a1 .. an <sep>  b1 .. bm <sep>  <cls>
            segment_ids:   4   ..   4     0 ..  0   0     1 ..  1   1      2
            input_mask:    0   ..   0     1 ..  1   1     1 ..  1   1      1

        The ``b`` part is only present when ``example.text_b`` is set.

        Args:
            examples: sized sequence of ``InputExample``.
            max_seq_len: length of every feature, including special tokens and
                padding.
            cached_features_file: Path-like object exposing ``.exists()``. When
                it exists, features are loaded from it; otherwise they are built
                and saved there.

        Returns:
            list of ``InputFeature``. ``input_len`` counts the non-padding tokens.
        """
        if cached_features_file.exists():
            logger.info("Loading features from cached file %s", cached_features_file)
            return self._load_cache(cached_features_file)

        pbar = ProgressBar(n_total=len(examples), desc='create features')
        pad_token_id = self.tokenizer.convert_tokens_to_ids([self.tokenizer.pad_token])[0]
        cls_token = self.tokenizer.cls_token
        sep_token = self.tokenizer.sep_token
        cls_token_segment_id = 2
        pad_token_segment_id = 4

        features = []
        for ex_id, example in enumerate(examples):
            tokens_a = self.tokenizer.tokenize(example.text_a)
            tokens_b = None
            if example.text_b:
                tokens_b = self.tokenizer.tokenize(example.text_b)
                # Reserve room for <sep>, <sep> and <cls>.
                self.truncate_seq_pair(tokens_a, tokens_b, max_length=max_seq_len - 3)
            elif len(tokens_a) > max_seq_len - 2:
                # Reserve room for <sep> and <cls>.
                tokens_a = tokens_a[:max_seq_len - 2]

            tokens = tokens_a + [sep_token]
            segment_ids = [0] * len(tokens)
            if tokens_b:
                tokens += tokens_b + [sep_token]
                segment_ids += [1] * (len(tokens_b) + 1)
            # XLNet puts its classification token at the end of the sequence.
            tokens.append(cls_token)
            segment_ids.append(cls_token_segment_id)

            input_ids = self.tokenizer.convert_tokens_to_ids(tokens)
            input_len = len(input_ids)
            input_mask = [1] * input_len
            padding_len = max_seq_len - input_len
            # XLNet is padded on the left.
            input_ids = [pad_token_id] * padding_len + input_ids
            input_mask = [0] * padding_len + input_mask
            segment_ids = [pad_token_segment_id] * padding_len + segment_ids
            assert len(input_ids) == max_seq_len
            assert len(input_mask) == max_seq_len
            assert len(segment_ids) == max_seq_len

            if ex_id < 2:
                logger.info("*** Example ***")
                # Plain f-string: the former trailing ``% ()`` broke on guids containing '%'.
                logger.info(f"guid: {example.guid}")
                logger.info(f"tokens: {' '.join([str(x) for x in tokens])}")
                logger.info(f"input_ids: {' '.join([str(x) for x in input_ids])}")
                logger.info(f"input_mask: {' '.join([str(x) for x in input_mask])}")
                logger.info(f"segment_ids: {' '.join([str(x) for x in segment_ids])}")

            features.append(InputFeature(input_ids=input_ids,
                                         input_mask=input_mask,
                                         segment_ids=segment_ids,
                                         label_id=example.label,
                                         input_len=input_len))
            pbar(step=ex_id)
        logger.info("Saving features into cached file %s", cached_features_file)
        torch.save(features, cached_features_file)
        return features

    def create_dataset(self, features, is_sorted=False):
        """Stack features into a ``TensorDataset``.

        Args:
            features: sequence of ``InputFeature``-like objects.
            is_sorted: if True, order features by ``input_len``, longest first.

        Returns:
            ``TensorDataset`` of (input_ids, input_mask, segment_ids, label_ids,
            input_lens), every tensor with dtype ``torch.long``.
        """
        if is_sorted:
            logger.info("sorted data by th length of input")
            features = sorted(features, key=lambda x: x.input_len, reverse=True)
        all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long)
        all_segment_ids = torch.tensor([f.segment_ids for f in features], dtype=torch.long)
        # NOTE(review): labels are floats upstream but stored as long here, so any
        # fractional (soft) label would be truncated toward zero.
        all_label_ids = torch.tensor([f.label_id for f in features], dtype=torch.long)
        all_input_lens = torch.tensor([f.input_len for f in features], dtype=torch.long)
        return TensorDataset(all_input_ids, all_input_mask, all_segment_ids,
                             all_label_ids, all_input_lens)