"""chimeranet.reconstruction: build (N, C, F, T) masks from embeddings or mask tensors."""


import numpy as np
import sklearn

def from_embedding(embedding, n_channels, n_jobs=-1):
    """Cluster per-bin embedding vectors into one-hot masks with k-means.

    Parameters
    ----------
    embedding : array_like, shape (N, T, F, D)
        D-dimensional embedding vector for every (batch, time, frequency) bin.
    n_channels : int
        Number of k-means clusters C, i.e. the number of output mask channels.
    n_jobs : int, optional
        Forwarded to ``sklearn.cluster.KMeans`` only when the installed
        scikit-learn still accepts it (it was removed in scikit-learn 1.0).
        Otherwise it is ignored.

    Returns
    -------
    numpy.ndarray, shape (N, C, F, T)
        One-hot float mask. ``mask[n, c, f, t] == 1.0`` exactly when bin
        (n, t, f) was assigned to cluster c. Channel order follows the
        arbitrary k-means label numbering.

    Raises
    ------
    ValueError
        If ``embedding`` is not 4-dimensional.
    """
    import inspect
    # `import sklearn` alone does not load the `sklearn.cluster` submodule.
    from sklearn.cluster import KMeans

    embedding = np.asarray(embedding)
    if embedding.ndim != 4:
        raise ValueError(
            'embedding must have shape (N, T, F, D), got shape {}'.format(
                embedding.shape))
    embedding_dim = embedding.shape[-1]

    kmeans_kwargs = {'n_clusters': n_channels}
    # `n_jobs` was deprecated in scikit-learn 0.23 and removed in 1.0;
    # passing it to newer versions raises TypeError, so forward it only
    # when the installed KMeans still declares it.
    if 'n_jobs' in inspect.signature(KMeans).parameters:
        kmeans_kwargs['n_jobs'] = n_jobs

    # Flatten all bins into an (N*T*F, D) sample matrix for clustering.
    labels = KMeans(**kmeans_kwargs).fit(
        embedding.reshape(-1, embedding_dim)
    ).labels_

    # One-hot encode labels -> (N*T*F, C), restore (N, T, F, C),
    # then reorder axes to (N, C, F, T).
    mask = np.eye(n_channels)[labels] \
        .reshape(embedding.shape[:-1] + (n_channels,)) \
        .transpose((0, 3, 2, 1))
    return mask
def from_mask(mask):
    """Reorder a mask tensor from (N, T, F, C) layout to (N, C, F, T).

    Parameters
    ----------
    mask : object with a ``transpose(axes)`` method, shape (N, T, F, C)
        For example a NumPy array.

    Returns
    -------
    Result of ``mask.transpose`` with shape (N, C, F, T). For a NumPy
    array this is a view of the input, not a copy.
    """
    # Batch axis stays first and frequency stays third; time and channel swap.
    ntfc_to_ncft = (0, 3, 2, 1)
    return mask.transpose(ntfc_to_ncft)