Skip to content
Snippets Groups Projects
Commit c6929e71 authored by MARMORET Axel's avatar MARMORET Axel
Browse files

Adding some warnings

parent 208eceb1
No related branches found
No related tags found
No related merge requests found
......@@ -7,13 +7,15 @@ Created on Tue Jun 11 15:49:25 2019
import numpy as np
import time
import math
import warnings
import nn_fac.update_rules.nnls as nnls
import nn_fac.update_rules.mu as mu
import nn_fac.utils.beta_divergence as beta_div
import nn_fac.utils.errors as err
from nimfa.methods import seeding
import math
from nimfa.methods import seeding
def nmf(data, rank, init = "random", U_0 = None, V_0 = None, n_iter_max=100, tol=1e-8,
update_rule = "hals", beta = 2,
......@@ -171,6 +173,11 @@ def nmf(data, rank, init = "random", U_0 = None, V_0 = None, n_iter_max=100, tol
Learning the parts of objects by non-negative matrix factorization.
Nature, 401(6755), 788-791.
"""
if min(data.shape) < rank:
min_data = min(data.shape)
rank = min_data
warnings.warn(f"The rank is too high for the input matrix. It was set to {min_data} instead.")
if init.lower() == "random":
k, n = data.shape
if deterministic:
......@@ -182,6 +189,8 @@ def nmf(data, rank, init = "random", U_0 = None, V_0 = None, n_iter_max=100, tol
V_0 = np.random.rand(rank, n)
elif init.lower() == "nndsvd":
with warnings.catch_warnings():
warnings.simplefilter("ignore") # A warning arises from the nimfa toolbox, because of the sue of np.asmatrix.
U_0, V_0 = seeding.Nndsvd().initialize(data, rank, {'flag': 0})
U_0 = np.array(U_0 + 1e-12)
V_0 = np.array(V_0 + 1e-12)
......
......@@ -8,6 +8,8 @@ Created on Tue Jun 11 16:52:21 2019
import numpy as np
import time
import tensorly as tl
import warnings
from nimfa.methods import seeding
import nn_fac.update_rules.nnls as nnls
......@@ -193,6 +195,8 @@ def ntf(tensor, rank, init = "random", factors_0 = [], n_iter_max=100, tol=1e-8,
if tensor.shape[mode] < rank:
current_factor = np.random.rand(tensor.shape[mode], rank)
else:
with warnings.catch_warnings():
warnings.simplefilter("ignore") # A warning arises from the nimfa toolbox, because of the sue of np.asmatrix.
current_factor, useless_variable = seeding.Nndsvd().initialize(tl.unfold(tensor, mode), rank, {'flag': 0})
factors.append(tl.tensor(current_factor))
......
......@@ -5,6 +5,7 @@ Created on Tue Jun 11 17:12:33 2019
@author: amarmore
"""
import warnings
import numpy as np
import time
import nn_fac.update_rules.nnls as nnls
......@@ -194,6 +195,8 @@ def parafac_2(tensor_slices, rank, init_with_P, init = "random", W_list_in = Non
elif init.lower() == "nndsvd":
for k in range(nb_channel):
with warnings.catch_warnings():
warnings.simplefilter("ignore") # A warning arises from the nimfa toolbox, because of the sue of np.asmatrix.
W_k, H = seeding.Nndsvd().initialize(tensor_slices[k], rank, {'flag': 0})
W_list.append(W_k)
D_list.append(np.diag(np.random.rand(rank)))
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment