Skip to content
Snippets Groups Projects
Commit f1c328f3 authored by MARMORET Axel's avatar MARMORET Axel
Browse files

Minor changes (name changing mainly)

parent 1ec1475c
No related branches found
No related tags found
No related merge requests found
import warnings
import numpy as np
from nn_fac.utils.normalize_wh import normalize_WH
import nn_fac.utils.normalize_wh as normalize_wh
from nn_fac.utils.beta_divergence import beta_divergence
eps = 1e-12
......@@ -27,7 +27,7 @@ def KL_mu_min_vol(data, W, H, delta, lambda_, gamma = None, tol_update_lagrangia
else:
lagragian_multipliers_0 = np.zeros((k, 1)) #(D[:,0] - C[:,0] * W[:,0]).T
lagragian_multipliers = update_lagragian_multipliers(C, S, D, W, lagragian_multipliers_0, tol_update_lagrangian)
lagragian_multipliers = update_lagragian_multipliers_Wminvol(C, S, D, W, lagragian_multipliers_0, tol_update_lagrangian)
W = W * ((((C + Jm1 @ lagragian_multipliers.T) ** 2 + S) ** 0.5 - (C + Jm1 @ lagragian_multipliers.T)) / (D + eps))
......@@ -44,15 +44,15 @@ def gamma_line_search(data, W_update, W_gamma_init, H_gamma_init, beta, delta, g
while cur_err > prev_error and gamma > 1e-16:
gamma *= 0.8
W_gamma = (1 - gamma) * W_prev + gamma * W_update
W_gamma, H_gamma = normalize_WH(W_gamma, H_gamma, "W")
W_gamma, H_gamma = normalize_wh.normalize_WH(W_gamma, H_gamma, "W")
cur_log_det = compute_log_det(W_gamma, delta)
cur_err = beta_divergence(data, W_gamma @ H_gamma, beta) + lambda_tilde * cur_log_det
gamma = min(gamma * 1.2, 1)
return W_gamma, H_gamma, gamma
def update_lagragian_multipliers(C, S, D, W, lagrangian_multipliers_0, tol = 1e-6, n_iter_max = 1000):
# Comes from
def update_lagragian_multipliers_Wminvol(C, S, D, W, lagrangian_multipliers_0, tol = 1e-6, n_iter_max = 100):
# Comes from Multiplicative Updates for NMF with β-Divergences under Disjoint Equality Constraints, https://arxiv.org/pdf/2010.16223.pdf
m, k = W.shape
Jm1 = np.ones((m,1))
Jk1 = np.ones(k)
......@@ -74,7 +74,7 @@ def update_lagragian_multipliers(C, S, D, W, lagrangian_multipliers_0, tol = 1e-
break
if iter == n_iter_max - 1:
warnings.warn('Maximum of iterations reached in the update of mu.')
warnings.warn('Maximum of iterations reached in the update of the Lagrangian multipliers.')
return lagrangian_multipliers
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment