Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found
Select Git revision
  • master
  • v0.2.0
  • v0.2.1
3 results

Target

Select target project
  • a23marmo/nonnegative-factorization
1 result
Select Git revision
  • master
  • v0.2.0
  • v0.2.1
3 results
Show changes
Commits on Source (2)
......@@ -477,7 +477,3 @@ if __name__ == "__main__":
W, H = nmf(data, rank, beta = 0, update_rule = "mu", n_iter_max = 100, init="random", verbose = True)
W, H = nmf(data, rank, beta = 0, update_rule = "mu", n_iter_max = 100, init = "nndsvd",verbose = True)
# TO DEBUG
# W, H, lossfun, t = minvol_beta_nmf(data, rank, beta = 1, n_iter_max = 100, init = "nndsvd", gamma_line_search = True, gamma = 1, verbose = True)
# W, H, lossfun, t = minvol_beta_nmf(data, rank, beta = 1, n_iter_max = 100, i gamma_line_search = False, verbose = True)
import warnings
import numpy as np
from nn_fac.utils.normalize_wh import normalize_WH
import nn_fac.utils.normalize_wh as normalize_wh
from nn_fac.utils.beta_divergence import beta_divergence
eps = 1e-12
......@@ -27,7 +27,7 @@ def KL_mu_min_vol(data, W, H, delta, lambda_, gamma = None, tol_update_lagrangia
else:
lagragian_multipliers_0 = np.zeros((k, 1)) #(D[:,0] - C[:,0] * W[:,0]).T
lagragian_multipliers = update_lagragian_multipliers(C, S, D, W, lagragian_multipliers_0, tol_update_lagrangian)
lagragian_multipliers = update_lagragian_multipliers_Wminvol(C, S, D, W, lagragian_multipliers_0, tol_update_lagrangian)
W = W * ((((C + Jm1 @ lagragian_multipliers.T) ** 2 + S) ** 0.5 - (C + Jm1 @ lagragian_multipliers.T)) / (D + eps))
......@@ -44,15 +44,15 @@ def gamma_line_search(data, W_update, W_gamma_init, H_gamma_init, beta, delta, g
while cur_err > prev_error and gamma > 1e-16:
gamma *= 0.8
W_gamma = (1 - gamma) * W_prev + gamma * W_update
W_gamma, H_gamma = normalize_WH(W_gamma, H_gamma, "W")
W_gamma, H_gamma = normalize_wh.normalize_WH(W_gamma, H_gamma, "W")
cur_log_det = compute_log_det(W_gamma, delta)
cur_err = beta_divergence(data, W_gamma @ H_gamma, beta) + lambda_tilde * cur_log_det
gamma = min(gamma * 1.2, 1)
return W_gamma, H_gamma, gamma
def update_lagragian_multipliers(C, S, D, W, lagrangian_multipliers_0, tol = 1e-6, n_iter_max = 1000):
# Comes from
def update_lagragian_multipliers_Wminvol(C, S, D, W, lagrangian_multipliers_0, tol = 1e-6, n_iter_max = 100):
# Comes from Multiplicative Updates for NMF with β-Divergences under Disjoint Equality Constraints, https://arxiv.org/pdf/2010.16223.pdf
m, k = W.shape
Jm1 = np.ones((m,1))
Jk1 = np.ones(k)
......@@ -74,7 +74,7 @@ def update_lagragian_multipliers(C, S, D, W, lagrangian_multipliers_0, tol = 1e-
break
if iter == n_iter_max - 1:
warnings.warn('Maximum of iterations reached in the update of mu.')
warnings.warn('Maximum of iterations reached in the update of the Lagrangian multipliers.')
return lagrangian_multipliers
......