Source code for cluster_predict

"""
.. module:: cluster_predict
    :synopsis: clustering by set generation algorithm

.. moduleauthor:: Jiaming Shen
"""
import torch
import numpy as np
from tqdm import tqdm


def multiple_set_single_instance_prediction(model, sets, instance, size_optimized=False):
    """ Apply the given model to predict the probabilities of adding one instance into each of the given sets

    :param model: a trained SynSetMine model
    :type model: SSPM
    :param sets: a list of sets, each containing element indices
    :type sets: list
    :param instance: a single instance, represented by its element index
    :type instance: int
    :param size_optimized: whether to optimize the multiple-set-single-instance prediction process.
        Set this parameter to True if the sizes of the given 'sets' vary a lot and 'sets' contains a single huge set
    :type size_optimized: bool
    :return:
        - scores of given sets, (batch_size, 1)
        - scores of given sets union with the instance, (batch_size, 1)
        - the probability of adding the instance into the corresponding set, (batch_size, 1)
    :rtype: tuple
    """
    if not size_optimized:  # when there exists no single big cluster, no need for complex size optimization
        return _multiple_set_single_instance_prediction(model, sets, instance)
    else:
        if len(sets) <= 10:
            return _multiple_set_single_instance_prediction(model, sets, instance)
        set_sizes = [len(ele) for ele in sets]
        tmp = sorted(enumerate(set_sizes), key=lambda x: x[1])  # (old index, set_size)
        n2o = {n: ele[0] for n, ele in enumerate(tmp)}  # new index -> old index
        o2n = {n2o[n]: n for n in n2o}  # old index -> new index
        sorted_set_sizes = [ele[1] for ele in tmp]
        # bins="auto" combines the 'sturges' and 'fd' estimators; another choice is bins="sturges", which generates more bins
        # c.f.: https://docs.scipy.org/doc/numpy/reference/generated/numpy.histogram_bin_edges.html#numpy.histogram_bin_edges
        _, bin_edges = np.histogram(sorted_set_sizes, bins="auto")
        inds = np.digitize(sorted_set_sizes, bin_edges)
        sorted_setScores = []
        sorted_setInstSumScores = []
        sorted_positive_prob = []
        cur_ind = inds[0]
        cur_set = [sets[tmp[0][0]]]
        for i in range(1, len(inds)):
            if inds[i] == cur_ind:
                cur_set.append(sets[tmp[i][0]])
            else:
                cur_setScores, cur_setInstSumScores, cur_positive_prob = _multiple_set_single_instance_prediction(
                    model, cur_set, instance)
                sorted_setScores += cur_setScores
                sorted_setInstSumScores += cur_setInstSumScores
                sorted_positive_prob += cur_positive_prob
                cur_ind = inds[i]
                cur_set = [sets[tmp[i][0]]]
        if len(cur_set) > 0:  # working on the last bin
            cur_setScores, cur_setInstSumScores, cur_positive_prob = _multiple_set_single_instance_prediction(
                model, cur_set, instance)
            sorted_setScores += cur_setScores
            sorted_setInstSumScores += cur_setInstSumScores
            sorted_positive_prob += cur_positive_prob
        assert len(sets) == len(sorted_positive_prob), "Mismatch after binning optimization"
        setScores = []
        setInstSumScores = []
        positive_prob = []
        for o in range(len(sets)):
            setScores.append(sorted_setScores[o2n[o]])
            setInstSumScores.append(sorted_setInstSumScores[o2n[o]])
            positive_prob.append(sorted_positive_prob[o2n[o]])
        return setScores, setInstSumScores, positive_prob
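# --- Illustrative sketch (not part of the original module) --------------------
# The payoff of the binning step above is easiest to see in isolation. This toy
# demo reproduces the same histogram-and-digitize grouping on made-up set sizes:
# the one huge set is separated from the many small ones, so each batch is
# padded only to its own bin's maximum size instead of to the global maximum.
def _demo_size_binning():
    set_sizes = [2, 3, 2, 4, 3, 2, 5, 4, 3, 2, 120]  # many small sets, one huge set
    sorted_sizes = sorted(set_sizes)
    _, bin_edges = np.histogram(sorted_sizes, bins="auto")
    bin_ids = np.digitize(sorted_sizes, bin_edges)
    for bin_id in sorted(set(bin_ids)):
        sizes_in_bin = [s for s, b in zip(sorted_sizes, bin_ids) if b == bin_id]
        print("bin {}: sizes {}".format(bin_id, sizes_in_bin))
    # Expected: the size-120 set lands in a bin of its own, away from sizes 2-5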
def _multiple_set_single_instance_prediction(model, sets, instance):
    model.eval()
    # generate tensors
    batch_size = len(sets)
    max_set_size = max([len(ele) for ele in sets])
    batch_set_tensor = np.zeros([batch_size, max_set_size], dtype=np.int64)
    for row_id, row in enumerate(sets):
        batch_set_tensor[row_id][:len(row)] = row
    batch_set_tensor = torch.from_numpy(batch_set_tensor)  # (batch_size, max_set_size)
    batch_inst_tensor = torch.tensor(instance).unsqueeze(0).expand(batch_size, 1)  # (batch_size, 1)
    batch_set_tensor = batch_set_tensor.to(model.device)
    batch_inst_tensor = batch_inst_tensor.to(model.device)

    # inference
    setScores, setInstSumScores, prediction = model.predict(batch_set_tensor, batch_inst_tensor)

    # convert the prediction of each set-instance pair to a probability
    positive_prob = prediction.squeeze(-1).detach()
    positive_prob = list(positive_prob.to(torch.device("cpu")).numpy())
    setScores = setScores.squeeze(-1).detach()
    setScores = list(setScores.to(torch.device("cpu")).numpy())
    setInstSumScores = setInstSumScores.squeeze(-1).detach()
    setInstSumScores = list(setInstSumScores.to(torch.device("cpu")).numpy())
    model.train()
    return setScores, setInstSumScores, positive_prob
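# --- Illustrative sketch (not part of the original module) --------------------
# A minimal, self-contained look at the tensor construction above, with toy
# element indices in place of a real vocabulary: every set is right-padded with
# zeros to the longest set in the batch, and the single candidate instance is
# repeated once per set so all (set, instance) pairs are scored in one forward pass.
def _demo_batch_tensors():
    sets = [[3, 7], [5], [2, 9, 4]]
    instance = 6
    batch_size = len(sets)
    max_set_size = max(len(ele) for ele in sets)
    batch_set_tensor = np.zeros([batch_size, max_set_size], dtype=np.int64)
    for row_id, row in enumerate(sets):
        batch_set_tensor[row_id][:len(row)] = row
    batch_set_tensor = torch.from_numpy(batch_set_tensor)
    batch_inst_tensor = torch.tensor(instance).unsqueeze(0).expand(batch_size, 1)
    print(batch_set_tensor)         # tensor([[3, 7, 0], [5, 0, 0], [2, 9, 4]])
    print(batch_inst_tensor.shape)  # torch.Size([3, 1])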
def set_generation(model, vocab, threshold=0.5, eid2ename=None, size_opt_clus=False, max_K=None, verbose=False):
    """ Set Generation Algorithm

    :param model: a trained set-instance classifier
    :type model: SSPM
    :param vocab: a list of elements to be clustered, each element is represented by its index
    :type vocab: list
    :param threshold: the probability threshold for deciding whether to create a new singleton cluster
    :type threshold: float
    :param eid2ename: a dictionary mapping each element index to its corresponding (human-readable) name
    :type eid2ename: dict
    :param max_K: maximum number of clusters. If None, this number is inferred automatically
    :type max_K: int
    :param size_opt_clus: a flag indicating whether to optimize the multiple-set-single-instance prediction process
    :type size_opt_clus: bool
    :param verbose: whether to print out all intermediate results
    :type verbose: bool
    :return: a list of detected clusters
    :rtype: list
    """
    model.eval()
    clusters = []  # will be a list of lists
    candidate_pool = vocab
    if verbose:
        print("{}\t{}".format("vocab", [eid2ename[eid] for eid in vocab]))
    if verbose:
        g = tqdm(range(len(candidate_pool)), desc="Cluster prediction (aggressive one pass)...")
    else:
        g = range(len(candidate_pool))
    for i in g:
        inst = candidate_pool[i]
        if i == 0:  # the first element always starts a new cluster
            cluster = [inst]
            clusters.append(cluster)
        else:
            setScores, setInstSumScores, cluster_probs = multiple_set_single_instance_prediction(
                model, clusters, inst, size_optimized=size_opt_clus
            )
            best_matching_existing_cluster_idx = -1
            best_matching_existing_cluster_prob = 0.0
            for cid, cluster_prob in enumerate(cluster_probs):
                if cluster_prob > best_matching_existing_cluster_prob:
                    best_matching_existing_cluster_prob = cluster_prob
                    best_matching_existing_cluster_idx = cid
            if verbose:
                print("Current Cluster Pool:",
                      [(cid, [eid2ename[ele] for ele in cluster]) for cid, cluster in enumerate(clusters)])
                print("-" * 20)
                print("Entity: {:<30} best_prob = {:<8} Best-matching Cluster: {:<80} (cid={})".format(
                    eid2ename[inst], best_matching_existing_cluster_prob,
                    str([eid2ename[eid] for eid in clusters[best_matching_existing_cluster_idx]]),
                    best_matching_existing_cluster_idx))
            if max_K and len(clusters) >= max_K:  # cluster budget exhausted, must add to an existing cluster
                clusters[best_matching_existing_cluster_idx].append(inst)
                if verbose:
                    print("!!! Add Entity In")
            else:  # either add this instance to an existing cluster or create a new singleton cluster
                if best_matching_existing_cluster_prob > threshold:
                    clusters[best_matching_existing_cluster_idx].append(inst)
                    if verbose:
                        print("!!! Add Entity In")
                else:
                    new_cluster = [inst]
                    clusters.append(new_cluster)
        if verbose:
            print("-" * 120)
    model.train()
    return clusters
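# --- Illustrative sketch (not part of the original module) --------------------
# set_generation only touches the model through .eval(), .train(), .device, and
# .predict(...), so a toy stand-in is enough to exercise the whole one-pass
# algorithm end to end. _ToyModel and its parity rule below are invented for this
# example; they are NOT the real SSPM classifier.
class _ToyModel:
    device = torch.device("cpu")

    def eval(self):
        pass

    def train(self):
        pass

    def predict(self, batch_set_tensor, batch_inst_tensor):
        # Score a (set, instance) pair as "likely" (0.9) when the instance index
        # has the same parity as the set's first element, else "unlikely" (0.1)
        same_parity = (batch_set_tensor[:, :1] % 2) == (batch_inst_tensor % 2)
        prob = torch.where(same_parity, torch.tensor(0.9), torch.tensor(0.1))
        return torch.zeros_like(prob), torch.zeros_like(prob), prob


def _demo_set_generation():
    vocab = [1, 2, 3, 4, 5, 6]
    eid2ename = {eid: "e{}".format(eid) for eid in vocab}
    clusters = set_generation(_ToyModel(), vocab, threshold=0.5, eid2ename=eid2ename)
    print(clusters)  # expected: [[1, 3, 5], [2, 4, 6]]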