# Source code for dataloader.element_set
"""
.. module:: element_set
:synopsis: data loader for element set
.. moduleauthor:: Jiaming Shen
"""
import torch
import math
import numpy as np
import random
class ElementSet(object):
    """ Dataset Object

    Holds either a collection of element sets ("set" format) or a collection of <set, instance, label> triplets
    ("sip" format), and turns them into batches of tensors for set-instance prediction.

    :param name: dataset name; if it contains "test" and data_format is "set", evaluation <set, instance> pairs are
        pre-generated at construction time
    :type name: str
    :param data_format: dataset format, either "set" or "sip"
    :type data_format: str
    :param options: dataset parameters: "index2word" and "word2index" (the two dicts mapping between elements and
        element indices) and "device" (where batch tensors are moved)
    :type options: dict
    :param raw_data_strings: a list of strings representing an element set.
        - If data_format is "set", each string is of format "c0 {'d93', 'd377', 'd141', 'd63', 'd166'}".
        - If data_format is "sip", each string is of format "{'d93', 'd377'} d141 0".
    :type raw_data_strings: list
    """

    def __init__(self, name, data_format, options, raw_data_strings=None):
        self.name = name
        self.data_format = data_format
        self.index2word = options["index2word"]
        self.word2index = options["word2index"]
        self.device = options["device"]
        # sorted indices of the elements appearing in the loaded sets ("sip" format: the "set" side of each pair)
        self.vocab = []
        self.max_set_size = -1  # the max_set_size in this dataset
        self.min_set_size = 1e8  # the min_set_size in this dataset
        self.avg_set_size = -1  # the avg_set_size in this dataset
        # a list of element sets, each a sorted list of element indices
        self.positive_sets = []
        # a list of <set, instance, label> triplets
        self.sip_triplets = []
        # number of positive / negative triplets; -1 until they are computed
        self.pos_sip_cnt = -1
        self.neg_sip_cnt = -1
        # used to generate a collection of <set, instance> pair for evaluation in advance.
        # NOTE(review): neither constant is referenced by the methods of this class.
        self.NEG_SAMPLE_RATIO = 10  # for each positive (set, instance) pair, generate at most 10 negative pairs
        self.MAX_POS_SUB_SET_CNT = 500  # for each full set, generate at most 500 positive (set, instance) pairs
        if self.data_format == "set":
            self._initialize_set_format(raw_data_strings)
        elif self.data_format == "sip":
            self._initialize_sip_format(raw_data_strings)
        # for test set, generate sip triplets for evaluation of set-instance prediction; the negative sampling
        # strategy, negative sample size, and max set size are all fixed
        if "test" in self.name and self.data_format == "set":
            self.sip_triplets, self.pos_sip_cnt, self.neg_sip_cnt = self._convert_set_format_to_sip_format(
                raw_sets=self.positive_sets, pos_strategy="vary_size_enumerate_with_full_set",
                neg_strategy="complete-random", neg_sample_size=10, max_set_size=50)

    def __repr__(self):
        return "<ElementSet {} (data_format = {}, vocab_size = {}, number of sets = {}, " \
               "max_set_size = {}, min_set_size = {}, avg_set_size = {}, number of set-instance pairs = {}, " \
               "positive pairs = {}, negative pairs = {})>".format(
                   self.name, self.data_format, len(self.vocab), len(self.positive_sets), self.max_set_size,
                   self.min_set_size, self.avg_set_size, len(self.sip_triplets), self.pos_sip_cnt, self.neg_sip_cnt)

    def __len__(self):
        """Number of sets ("set" format) or of <set, instance> triplets ("sip" format)."""
        if self.data_format == "set":
            return len(self.positive_sets)
        elif self.data_format == "sip":
            return len(self.sip_triplets)

    def _initialize_set_format(self, raw_set_strings):
        """Initialize dataset from a collection of strings representing element sets

        Fills self.positive_sets, self.vocab and the max/min/avg set-size statistics. Blank lines are skipped.

        :param raw_set_strings: a list of strings representing element sets (None is treated as empty)
        :type raw_set_strings: list
        :return: None
        :rtype: None
        """
        set_size_sum = 0  # used to calculate self.avg_set_size
        for line in raw_set_strings or []:
            line = line.strip()
            if not line:  # tolerate blank lines, e.g. a trailing empty line in the input file
                continue
            _, set_literal = line.split(" ", 1)  # the leading token is the set id, which is not used here
            # NOTE(review): eval() executes arbitrary code -- only load trusted data files here
            # (ast.literal_eval would be a safer parser for non-empty set literals).
            elements = sorted(eval(set_literal))  # sorting for reproducibility
            self.max_set_size = max(self.max_set_size, len(elements))
            self.min_set_size = min(self.min_set_size, len(elements))
            set_size_sum += len(elements)
            element_ids = [self.word2index[ele] for ele in elements]
            self.positive_sets.append(sorted(element_ids))  # sorting for reproducibility
            self.vocab.extend(element_ids)
        if self.positive_sets:  # avoid ZeroDivisionError on empty input
            self.avg_set_size = 1.0 * set_size_sum / len(self.positive_sets)
        self.vocab = sorted(set(self.vocab))  # sorting for reproducibility

    def _initialize_sip_format(self, raw_set_instance_strings):
        """ Initialize dataset from a collection of strings representing <set, instance> pairs

        Fills self.sip_triplets, self.vocab (elements of the "set" side only), self.max_set_size (set size + 1)
        and the exact positive / negative pair counts. Blank lines are skipped.

        :param raw_set_instance_strings: a list of strings representing <set instance> pairs (None is treated as empty)
        :type raw_set_instance_strings: list
        :return: None
        :rtype: None
        """
        # the counters start at -1 ("not computed"); reset them so the counts below are exact
        self.pos_sip_cnt = 0
        self.neg_sip_cnt = 0
        for line in raw_set_instance_strings or []:
            line = line.strip()
            if not line:
                continue
            segs = line.split(" ")
            label = int(segs[-1])
            instance = self.word2index[segs[-2]]
            # NOTE(review): eval() executes arbitrary code -- only load trusted data files here.
            subset = sorted(self.word2index[ele] for ele in eval(" ".join(segs[:-2])))  # sorting for reproducibility
            self.sip_triplets.append((subset, instance, label))
            if label == 1:
                self.pos_sip_cnt += 1
            else:
                self.neg_sip_cnt += 1
            self.vocab.extend(subset)
            self.max_set_size = max(self.max_set_size, len(subset) + 1)
        self.vocab = sorted(set(self.vocab))  # sorting for reproducibility

    def get_train_batch(self, max_set_size=100, pos_sample_method="sample_size_random_set", neg_sample_size=1,
                        neg_sample_method="complete_random", batch_size=32):
        """ Generate training batches of <set, instance> pairs

        In "set" format, every batch_size sets are converted to sips with the given sampling methods; in "sip" format,
        the stored triplets are cut into chunks of batch_size * (1 + neg_sample_size). A final smaller batch is
        yielded for any remainder.

        :param max_set_size: maximum size of set S
        :type max_set_size: int
        :param pos_sample_method: name of positive sampling method
        :type pos_sample_method: str
        :param neg_sample_size: number of negative samples for each set
        :type neg_sample_size: int
        :param neg_sample_method: name of negative sampling method
        :type neg_sample_method: str
        :param batch_size: number of **sets** in one batch
        :type batch_size: int
        :return: generator of training batches, each containing "batch_size * (1+neg_sample_size)" <set, instance>
            pairs (fewer in the last batch)
        :rtype: generator of dict
        """
        if self.data_format == "set":
            raw_sets = []
            for raw_set in self.positive_sets:
                raw_sets.append(raw_set)
                if len(raw_sets) == batch_size:
                    sip_triplets = self._convert_set_format_to_sip_format(
                        raw_sets=raw_sets, pos_strategy=pos_sample_method, neg_strategy=neg_sample_method,
                        neg_sample_size=neg_sample_size, max_set_size=max_set_size)
                    yield self._sip_triplets_to_tensor(max_set_size, sip_triplets)
                    raw_sets = []
            if raw_sets:  # yield the last batch
                sip_triplets = self._convert_set_format_to_sip_format(
                    raw_sets=raw_sets, pos_strategy=pos_sample_method, neg_strategy=neg_sample_method,
                    neg_sample_size=neg_sample_size, max_set_size=max_set_size)
                yield self._sip_triplets_to_tensor(max_set_size, sip_triplets)
        elif self.data_format == "sip":
            pairs_per_batch = batch_size * (1 + neg_sample_size)
            for start in range(0, len(self.sip_triplets), pairs_per_batch):
                yield self._sip_triplets_to_tensor(max_set_size, self.sip_triplets[start:start + pairs_per_batch])

    def get_test_batch(self, max_set_size=5, batch_size=32):
        """ Generate testing batches of <set, instance> pairs

        :param max_set_size: maximum size of set S
        :type max_set_size: int
        :param batch_size: number of **<set, instance> pairs** in one batch
        :type batch_size: int
        :return: generator of testing batches, each containing "batch_size" <set, instance> pairs; the last batch holds
            the remainder (if any), so every stored pair is yielded exactly once
        :rtype: generator of dict
        """
        for start in range(0, len(self.sip_triplets), batch_size):
            yield self._sip_triplets_to_tensor(max_set_size, self.sip_triplets[start:start + batch_size])

    def _shuffle(self):
        """ Shuffle dataset

        :return: None
        :rtype: None
        """
        if self.data_format == "set":
            random.shuffle(self.positive_sets)
        elif self.data_format == "sip":
            random.shuffle(self.sip_triplets)

    def _convert_set_format_to_sip_format(self, raw_sets, pos_strategy, neg_strategy, neg_sample_size=10,
                                          subset_size=5, max_set_size=50):
        """ Generate <set, instance> pairs (sip) from a collection of sets

        :param raw_sets: a list of sets, each a list of element indices (not modified)
        :type raw_sets: list
        :param pos_strategy: name of positive sampling method (see Notes)
        :type pos_strategy: str
        :param neg_strategy: name of negative sampling method; currently not consulted -- every strategy pairs subsets
            with instances drawn uniformly from the vocabulary elements outside the raw set
        :type neg_strategy: str
        :param neg_sample_size: negative sampling ratio
        :type neg_sample_size: int
        :param subset_size: size of "set" in <set, instance> pairs, used only in "fix_size_repeat_set" pos_strategy
        :type subset_size: int
        :param max_set_size: maximum size of "set" in <set, instance> pairs, used by "sample_size_random_set" and by
            the "vary_size_enumerate*" / "enumerate" strategies
        :type max_set_size: int
        :return: for "sample_size_repeat_set", "sample_size_random_set" and "fix_size_repeat_set": a list of
            len(raw_sets) * (1 + neg_sample_size) (set, instance, label) tuples, len(raw_sets) of them positive.
            For "vary_size_enumerate", "vary_size_enumerate_with_full_set" and "enumerate": a tuple
            (sip_triplets, positive count, negative count). For "vary_size_enumerate_with_full_set_plus_group_id":
            (sip_triplets, positive count, negative count, group ids).
        :rtype: list or tuple
        :raises ValueError: if pos_strategy is unknown, or the vocabulary has no element outside a raw set while
            negatives are requested

        Notes:
            - "sample_size_repeat_set": for each original set, sample the size of the "set" in the sip, repeat this
              generated set neg_sample_size times, and pair the copies with each negative instance. This is the
              strategy used in the original AAAI submission.
            - "sample_size_random_set": for each original set, sample one size of "set" in sip, and generate a
              separate set for each negative instance (the size is resampled whenever a shuffled pass over the set
              is used up).
            - "fix_size_repeat_set": for each original set, use the pre-determined subset_size to generate one "set"
              in the sip (all but one element when the set is too small), repeat this generated set neg_sample_size
              times, and pair the copies with each negative instance. This is the one used in cold-start training.
            - "vary_size_enumerate": for each original set and for each subset size up to max_set_size, enumerate
              the original set and generate all possible sips. This is the one used for converting test_set in
              "set" format to "sip" format.
            - "vary_size_enumerate_with_full_set": same as "vary_size_enumerate", except that each full set is also
              used as a "set" (when the subset size equals the set size).
              NOTE(review): that full-set group holds one positive pair whose instance is already a member of the
              set, plus negatives -- confirm this is intended.
            - "vary_size_enumerate_with_full_set_plus_group_id": same as "vary_size_enumerate_with_full_set", except
              that it also returns the group id of each sip; the group id is the index of the sip's raw set.
            - "enumerate": the "vary_size_enumerate" enumeration with list-typed triplets; used for pre-generating
              sip triplets.
        """
        if pos_strategy == "sample_size_repeat_set":
            batch_set, batch_pos, batch_fullset = [], [], []  # batch_fullset is used to generate negative samples
            for raw_set in raw_sets:
                if len(raw_set) == 1:
                    # a singleton cannot be split; pair the whole set with its only element
                    batch_set.append(raw_set)
                    batch_pos.append(raw_set[0])
                    batch_fullset.append(raw_set)
                    continue
                raw_set_new = raw_set.copy()
                random.shuffle(raw_set_new)
                batch_fullset.append(raw_set_new)
                sampled_size = random.randint(1, len(raw_set_new) - 1)
                batch_set.append(raw_set_new[:sampled_size])
                batch_pos.append(raw_set_new[sampled_size])
            # Randomly generate negative instances; each subset is repeated (1+neg_sample_size) times
            batch_neg = self._generate_negative_samples_within_pool(batch_fullset, neg_sample_size, remove_pos=True)
            return self._repeat_subset_triplets(batch_set, batch_pos, batch_neg)
        elif pos_strategy == "sample_size_random_set":
            # negatives are drawn for all sets before any subset is sampled
            batch_neg = self._generate_negative_samples_within_pool(raw_sets, neg_sample_size, remove_pos=True)
            sip_triplets = []
            for raw_set, neg_insts in zip(raw_sets, batch_neg):
                if len(raw_set) == 1:
                    # a singleton cannot be split; pair the whole set with its only element
                    subsets = [raw_set for _ in range(neg_sample_size + 1)]
                    pos_inst = raw_set[0]
                else:
                    subsets, pos_inst = self._sample_random_subsets(raw_set, neg_sample_size + 1, max_set_size)
                sip_triplets.append((subsets[0], pos_inst, 1))
                for subset, neg_inst in zip(subsets[1:], neg_insts):
                    sip_triplets.append((subset, neg_inst, 0))
            return sip_triplets
        elif pos_strategy == "fix_size_repeat_set":
            batch_set, batch_pos, batch_fullset = [], [], []  # batch_fullset is used to generate negative samples
            for raw_set in raw_sets:
                if len(raw_set) < subset_size + 1:  # cannot hold a "set" of size subset_size plus a positive
                    if len(raw_set) == 1:
                        batch_set.append(raw_set)
                        batch_pos.append(raw_set[0])
                        batch_fullset.append(raw_set)
                        continue
                    raw_set_new = raw_set.copy()
                    random.shuffle(raw_set_new)
                    batch_set.append(raw_set_new[1:])  # largest subset that still leaves one positive instance
                    batch_pos.append(raw_set_new[0])
                else:
                    raw_set_new = raw_set.copy()
                    random.shuffle(raw_set_new)
                    # use given subset_size to generate sips
                    batch_set.append(raw_set_new[:subset_size])
                    batch_pos.append(raw_set_new[subset_size])
                batch_fullset.append(raw_set_new)
            # Randomly generate negative instances; each subset is repeated (1+neg_sample_size) times
            batch_neg = self._generate_negative_samples_within_pool(batch_fullset, neg_sample_size, remove_pos=True)
            return self._repeat_subset_triplets(batch_set, batch_pos, batch_neg)
        elif pos_strategy == "vary_size_enumerate":
            return self._vary_size_enumerate(raw_sets, neg_sample_size, max_set_size,
                                             include_full_set=False, return_groups=False)
        elif pos_strategy == "vary_size_enumerate_with_full_set":
            return self._vary_size_enumerate(raw_sets, neg_sample_size, max_set_size,
                                             include_full_set=True, return_groups=False)
        elif pos_strategy == "vary_size_enumerate_with_full_set_plus_group_id":
            return self._vary_size_enumerate(raw_sets, neg_sample_size, max_set_size,
                                             include_full_set=True, return_groups=True)
        elif pos_strategy == "enumerate":
            # works on shuffled copies, so the caller's sets (e.g. self.positive_sets) keep their sorted order
            sip_triplets, pos_sip_cnt_sum, neg_sip_cnt_sum = self._vary_size_enumerate(
                raw_sets, neg_sample_size, max_set_size, include_full_set=False, return_groups=False)
            return [list(triplet) for triplet in sip_triplets], pos_sip_cnt_sum, neg_sip_cnt_sum
        raise ValueError("unknown pos_strategy: {!r}".format(pos_strategy))

    @staticmethod
    def _repeat_subset_triplets(batch_set, batch_pos, batch_neg):
        """Pair each subset with its positive instance (label 1), then with each of its negative instances (label 0)."""
        sip_triplets = []
        for subset, pos_inst, neg_insts in zip(batch_set, batch_pos, batch_neg):
            sip_triplets.append((subset, pos_inst, 1))
            sip_triplets.extend((subset, neg_inst, 0) for neg_inst in neg_insts)
        return sip_triplets

    @staticmethod
    def _sample_random_subsets(raw_set, n_subsets, max_set_size):
        """Draw n_subsets random-size subsets from a shuffled copy of raw_set (which must have >= 2 elements).

        The first element of the shuffled copy is the positive instance and is excluded from the first subset;
        later subsets are cut from fresh shuffles and may contain it.

        :return: (list of subsets, positive instance)
        """
        elements = raw_set.copy()
        random.shuffle(elements)
        # one element is reserved for the positive instance; never allow a size below 1
        max_subset_size = max(1, min(len(elements), max_set_size) - 1)
        subset_size = random.randint(1, max_subset_size)
        pos_inst = elements[0]
        start_idx = 1  # treat the first element as positive instance and skip it
        subsets = []
        while len(subsets) < n_subsets:
            if start_idx + subset_size > len(elements):  # consume current pass, resample subset size
                random.shuffle(elements)
                subset_size = random.randint(1, max_subset_size)
                start_idx = 0
            subsets.append(elements[start_idx:start_idx + subset_size])
            start_idx += subset_size
        return subsets, pos_inst

    @staticmethod
    def _enumerate_disjoint_subsets(elements, subset_size, n_passes):
        """Cut `elements` into consecutive (subset, next element) pairs, reshuffling it in place after each pass.

        :return: (subsets, next elements), n_passes times the number of pairs of one pass
        """
        subsets = []
        next_elements = []
        for _ in range(n_passes):
            for start_idx in range(0, len(elements) - subset_size, subset_size + 1):
                subsets.append(elements[start_idx:start_idx + subset_size])
                next_elements.append(elements[start_idx + subset_size])
            random.shuffle(elements)
        return subsets, next_elements

    def _sample_negatives_by_gcd(self, positive_set, neg_cnt):
        """Draw neg_cnt vocabulary elements that are not in positive_set.

        Elements are drawn without replacement in rounds of gcd(neg_cnt, pool size), so the rounds cover neg_cnt
        exactly; an element can repeat across rounds but not within one.

        :raises ValueError: if neg_cnt > 0 but every vocabulary element is in positive_set
        """
        if neg_cnt == 0:
            return []
        excluded = set(positive_set)
        negative_pool = [ele for ele in self.vocab if ele not in excluded]
        if not negative_pool:
            raise ValueError("cannot sample negative instances: every vocabulary element is in the positive set")
        sample_size = math.gcd(neg_cnt, len(negative_pool))
        batch_neg = []
        for _ in range(neg_cnt // sample_size):
            batch_neg.extend(random.sample(negative_pool, sample_size))
        return batch_neg

    def _vary_size_enumerate(self, raw_sets, neg_sample_size, max_set_size, include_full_set, return_groups):
        """Shared implementation of the "vary_size_enumerate*" strategies.

        For each subset size in [1, max_set_size] and each raw set, the (subset, next element) pairs of
        (1 + neg_sample_size) random permutations are enumerated; the first pass's count of them become positives
        and the remaining subsets are paired with negatives. With include_full_set, a raw set whose size equals the
        subset size contributes itself (1 + neg_sample_size) times.

        :return: (sip_triplets, positive count, negative count), plus the per-triplet group ids (raw set indices)
            when return_groups is True
        """
        sip_triplets = []
        groups = []
        pos_sip_cnt_sum = 0
        neg_sip_cnt_sum = 0
        for subset_size in range(1, max_set_size + 1):
            # without the full-set variant, a raw set needs one element beyond the subset for the positive
            min_len = subset_size if include_full_set else subset_size + 1
            for group_id, raw_set in enumerate(raw_sets):
                if len(raw_set) < min_len:
                    continue
                raw_set_new = raw_set.copy()
                random.shuffle(raw_set_new)
                if len(raw_set) == subset_size:  # only reachable with include_full_set: put the entire full set
                    batch_set = []
                    batch_pos = []
                    for _ in range(neg_sample_size + 1):
                        batch_set.append(raw_set)
                        batch_pos.append(random.choice(raw_set))
                else:
                    batch_set, batch_pos = self._enumerate_disjoint_subsets(raw_set_new, subset_size,
                                                                            neg_sample_size + 1)
                pos_sip_cnt = len(batch_set) // (neg_sample_size + 1)  # pairs produced by a single pass
                neg_sip_cnt = pos_sip_cnt * neg_sample_size
                pos_sip_cnt_sum += pos_sip_cnt
                neg_sip_cnt_sum += neg_sip_cnt
                batch_neg = self._sample_negatives_by_gcd(raw_set, neg_sip_cnt)
                for idx, subset in enumerate(batch_set):
                    if idx < pos_sip_cnt:
                        sip_triplets.append((subset, batch_pos[idx], 1))
                    else:
                        sip_triplets.append((subset, batch_neg[idx - pos_sip_cnt], 0))
                    groups.append(group_id)
        if return_groups:
            return sip_triplets, pos_sip_cnt_sum, neg_sip_cnt_sum, groups
        return sip_triplets, pos_sip_cnt_sum, neg_sip_cnt_sum

    def _sip_triplets_to_tensor(self, max_set_size, sip_triplets):
        """Split (set, instance, label) triplets into columns and convert them with _convert_sip_format_to_tensor."""
        batch_set = [triplet[0] for triplet in sip_triplets]
        batch_inst = [triplet[1] for triplet in sip_triplets]
        labels = [triplet[2] for triplet in sip_triplets]
        return self._convert_sip_format_to_tensor(max_set_size, batch_set, batch_inst, labels)

    def _convert_sip_format_to_tensor(self, max_set_size, batch_set, batch_inst, labels):
        """ Generate tensors for <set, instance> pairs

        Sets longer than max_set_size are truncated; shorter ones are right-padded with 0.

        :param max_set_size: maximum size of "set" in <set, instance> pairs
        :type max_set_size: int
        :param batch_set: a list of "sets" in <set, instance> pairs
        :type batch_set: list
        :param batch_inst: a list of "instances" in <set, instance> pairs
        :type batch_inst: list
        :param labels: a list of labels for each above <set, instance> pair
        :type labels: list
        :return: a dict of pytorch tensors on self.device: 'set' (batch_size, max_set_size), 'inst' (batch_size, 1)
            and 'label' (batch_size, 1)
        :rtype: dict
        """
        batch_size = len(batch_set)
        # np.int was removed in NumPy 1.24; int64 is what it resolved to on Linux/macOS
        batch_set_array = np.zeros([batch_size, max_set_size], dtype=np.int64)
        for row_id, row in enumerate(batch_set):
            kept = min(len(row), max_set_size)
            batch_set_array[row_id, :kept] = row[:kept]
        batch_set_tensor = torch.from_numpy(batch_set_array)  # (batch_size, max_set_size)
        batch_inst_tensor = torch.tensor(batch_inst).unsqueeze(1)  # (batch_size, 1)
        label_tensor = torch.tensor(labels).unsqueeze(1)  # (batch_size, 1)
        return {'set': batch_set_tensor.to(self.device), 'inst': batch_inst_tensor.to(self.device),
                'label': label_tensor.to(self.device)}

    def _generate_negative_samples_within_pool(self, positive_sets, neg_sample_size, remove_pos=True):
        """ Generate negative samples from vocabulary

        :param positive_sets: a list of positive sets
        :type positive_sets: list
        :param neg_sample_size: negative sampling ratio
        :type neg_sample_size: int
        :param remove_pos: whether to remove instances in positive sets from the vocabulary
        :type remove_pos: bool
        :return: one list of neg_sample_size negative instances per positive set; when the pool is smaller than
            neg_sample_size, every pool element is repeated as often as it fits and the rest is sampled at random
        :rtype: list
        :raises ValueError: if negatives are requested but the sampling pool is empty
        """
        batch_neg = []
        for positive_set in positive_sets:
            if remove_pos:
                excluded = set(positive_set)
                sample_pool = [ele for ele in self.vocab if ele not in excluded]
            else:
                sample_pool = self.vocab
            if neg_sample_size <= len(sample_pool):
                neg = random.sample(sample_pool, neg_sample_size)
            else:
                if not sample_pool:
                    raise ValueError("cannot sample negative instances from an empty pool")
                repeat_time = neg_sample_size // len(sample_pool)
                neg = sample_pool * repeat_time  # new list; self.vocab is never modified
                remaining_num = neg_sample_size - len(sample_pool) * repeat_time
                neg += random.sample(sample_pool, remaining_num)
            batch_neg.append(neg)
        return batch_neg