diff --git a/.gitignore b/.gitignore
index bee1ff36b3bcf32022297bafb2a085d2c4021f70..ee90c321159d6133a1f0a2dd2e45ec49b03b8d40 100644
--- a/.gitignore
+++ b/.gitignore
@@ -146,7 +146,7 @@ nohup.out
 ## vscode config
 .vscode/
 
-.test.py
+test.py
 
 ## directories that outputs when running the tests
 tests/PLN*
diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml
index b99960db44d8ba2a5e74108f68a44f3bc6e2bffd..dbeb387706431aee3182104bd321fda1d0e9d015 100644
--- a/.gitlab-ci.yml
+++ b/.gitlab-ci.yml
@@ -46,4 +46,4 @@ pages:
   tags:
     - docker
   only:
-    - main
+    - tags
diff --git a/docs/source/index.rst b/docs/source/index.rst
index 3eb5a9736d0da2592327bc72aa3dcef41ce9fb8d..71513b20f7628bf8882b7a15cc63dd205e25516d 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -12,8 +12,8 @@ Welcome to pyPLNmodels's documentation!
 
    ./module.rst
 
-Indices and tables working again
-================================
+Indices and tables working once again
+=====================================
 
 * :ref:`genindex`
 * :ref:`modindex`
diff --git a/docs/source/module.rst b/docs/source/module.rst
index 752288b1381feedf2d6b53f398c4c006cbafc04b..7c22fce646290f47f5a58f3e77d18f2dc513dda4 100644
--- a/docs/source/module.rst
+++ b/docs/source/module.rst
@@ -11,3 +11,5 @@ API documentation
 .. autoclass:: pyPLNmodels.PLNPCA
    :members:
    :show-inheritance:
+   :special-members: __init__
+   :undoc-members:
diff --git a/pyPLNmodels/_closed_forms.py b/pyPLNmodels/_closed_forms.py
index 8c0ec8d200e1e27b89f5bdf7a30b8047ae07b8e7..1405b6df84a402fddf35bfc77f7c5f587c11c9e6 100644
--- a/pyPLNmodels/_closed_forms.py
+++ b/pyPLNmodels/_closed_forms.py
@@ -1,28 +1,100 @@
+from typing import Optional
+
 import torch  # pylint:disable=[C0114]
 
 
-def _closed_formula_covariance(covariates, latent_mean, latent_var, coef, n_samples):
-    """Closed form for covariance for the M step for the noPCA model."""
+def _closed_formula_covariance(
+    covariates: Optional[torch.Tensor],
+    latent_mean: torch.Tensor,
+    latent_var: torch.Tensor,
+    coef: torch.Tensor,
+    n_samples: int,
+) -> torch.Tensor:
+    """
+    Compute the closed-form covariance for the M step of the noPCA model.
+
+    Parameters
+    ----------
+    covariates : torch.Tensor or None
+        Covariates with size (n, d), or None.
+    latent_mean : torch.Tensor
+        Variational parameter with size (n, p).
+    latent_var : torch.Tensor
+        Variational parameter with size (n, p).
+    coef : torch.Tensor
+        Model parameter with size (d, p).
+    n_samples : int
+        Number of samples (n).
+
+    Returns
+    -------
+    torch.Tensor
+        The closed-form covariance with size (p, p).
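+
+    Examples
+    --------
+    A shape-check sketch (assumed usage, no covariates):
+
+    >>> import torch
+    >>> latent_mean, latent_var = torch.zeros(3, 2), torch.ones(3, 2)
+    >>> _closed_formula_covariance(None, latent_mean, latent_var, None, 3).shape
+    torch.Size([2, 2])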
+    """
     if covariates is None:
         XB = 0
     else:
         XB = covariates @ coef
-    m_moins_xb = latent_mean - XB
-    closed = m_moins_xb.T @ m_moins_xb + torch.diag(
+    m_minus_xb = latent_mean - XB
+    closed = m_minus_xb.T @ m_minus_xb + torch.diag(
         torch.sum(torch.square(latent_var), dim=0)
     )
     return closed / n_samples
 
 
-def _closed_formula_coef(covariates, latent_mean):
-    """Closed form for coef for the M step for the noPCA model."""
+def _closed_formula_coef(
+    covariates: Optional[torch.Tensor], latent_mean: torch.Tensor
+) -> Optional[torch.Tensor]:
+    """
+    Compute the closed-form coef for the M step of the noPCA model.
+
+    Parameters
+    ----------
+    covariates : torch.Tensor or None
+        Covariates with size (n, d), or None.
+    latent_mean : torch.Tensor
+        Variational parameter with size (n, p).
+
+    Returns
+    -------
+    Optional[torch.Tensor]
+        The closed-form coef with size (d, p) or None if covariates is None.
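+
+    Examples
+    --------
+    A minimal sketch (assumed usage):
+
+    >>> import torch
+    >>> _closed_formula_coef(None, torch.ones(2, 3)) is None
+    True
+    >>> _closed_formula_coef(torch.eye(2), torch.ones(2, 3)).shape
+    torch.Size([2, 3])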
+    """
     if covariates is None:
         return None
     return torch.inverse(covariates.T @ covariates) @ covariates.T @ latent_mean
 
 
 def _closed_formula_pi(
-    offsets, latent_mean, latent_var, dirac, covariates, _coef_inflation
-):
+    offsets: torch.Tensor,
+    latent_mean: torch.Tensor,
+    latent_var: torch.Tensor,
+    dirac: torch.Tensor,
+    covariates: torch.Tensor,
+    _coef_inflation: torch.Tensor,
+) -> torch.Tensor:
+    """
+    Compute the closed-form pi for the M step of the ZIPLN (zero-inflated) model.
+
+    Parameters
+    ----------
+    offsets : torch.Tensor
+        Offset with size (n, p).
+    latent_mean : torch.Tensor
+        Variational parameter with size (n, p).
+    latent_var : torch.Tensor
+        Variational parameter with size (n, p).
+    dirac : torch.Tensor
+        Dirac tensor.
+    covariates : torch.Tensor
+        Covariates with size (n, d).
+    _coef_inflation : torch.Tensor
+        Inflation coefficient tensor.
+
+    Returns
+    -------
+    torch.Tensor
+        The closed-form pi with the same size as dirac.
+    """
     poiss_param = torch.exp(offsets + latent_mean + 0.5 * torch.square(latent_var))
-    return torch._sigmoid(poiss_param + torch.mm(covariates, _coef_inflation)) * dirac
+    return torch.sigmoid(poiss_param + torch.mm(covariates, _coef_inflation)) * dirac
diff --git a/pyPLNmodels/_utils.py b/pyPLNmodels/_utils.py
index 3a8493f3868bf28d6dec286c9aa6927b1efab43b..4b977e0a830375ea43417b979bc5bdbe126f070a 100644
--- a/pyPLNmodels/_utils.py
+++ b/pyPLNmodels/_utils.py
@@ -1,15 +1,16 @@
-import math  # pylint:disable=[C0114]
-import warnings
 import os
-
-import matplotlib.pyplot as plt
+import math
+import warnings
 import numpy as np
+import pandas as pd
 import torch
 import torch.linalg as TLA
-import pandas as pd
-from matplotlib.patches import Ellipse
 from matplotlib import transforms
+from matplotlib.patches import Ellipse
+import matplotlib.pyplot as plt
 from patsy import dmatrices
+from typing import Optional, Dict, Any, Union
+
 
 torch.set_default_dtype(torch.float64)
 
@@ -20,35 +21,45 @@ else:
 
 
 class _PlotArgs:
-    def __init__(self, window):
+    def __init__(self, window: int):
+        """
+        Initialize the _PlotArgs class.
+
+        Parameters
+        ----------
+        window : int
+            The size of the window used for the stopping criterion.
+        """
         self.window = window
         self.running_times = []
         self.criterions = [1] * window
         self._elbos_list = []
 
     @property
-    def iteration_number(self):
+    def iteration_number(self) -> int:
+        """
+        Get the number of iterations.
+
+        Returns
+        -------
+        int
+            The number of iterations.
+        """
         return len(self._elbos_list)
 
     def _show_loss(self, ax=None, name_doss=""):
-        """Show the ELBO of the algorithm along the iterations.
-
-        args:
-            'ax': AxesSubplot object. The ELBO will be displayed in this ax
-                if not None. If None, will simply create an axis. Default
-                is None.
-            'name_file': str. The name of the file the graphic
-                will be saved to.
-                Default is 'fastPLNPCA_ELBO'.
-        returns: None but displays the ELBO.
         """
-        if ax is None:
-            ax = plt.gca()
-        ax.plot(
-            self.running_times,
-            -np.array(self._elbos_list),
-            label="Negative ELBO",
-        )
+        Show the loss plot.
+
+        Parameters
+        ----------
+        ax : matplotlib.axes.Axes, optional
+            The axes object to plot on. If not provided, the current axes will be used.
+        name_doss : str, optional
+            Name for a saved figure; currently unused. Default is an empty string.
+        """
+        ax = plt.gca() if ax is None else ax
+        ax.plot(self.running_times, -np.array(self._elbos_list), label="Negative ELBO")
         last_elbos = np.round(self._elbos_list[-1], 6)
         ax.set_title(f"Negative ELBO. Best ELBO ={last_elbos}")
         ax.set_yscale("log")
@@ -56,21 +67,16 @@ class _PlotArgs:
         ax.set_ylabel("ELBO")
         ax.legend()
 
-    def _show_stopping_criteration(self, ax=None):
-        """Show the criterion of the algorithm along the iterations.
-
-        args:
-            'ax': AxesSubplot object. The criterion will be displayed
-                in this ax
-                if not None. If None, will simply create an axis.
-                Default is None.
-            'name_file': str. The name of the file the graphic will
-                be saved to.
-                Default is 'fastPLN_criterion'.
-        returns: None but displays the criterion.
+    def _show_stopping_criterion(self, ax=None):
         """
-        if ax is None:
-            ax = plt.gca()
+        Show the stopping criterion plot.
+
+        Parameters
+        ----------
+        ax : matplotlib.axes.Axes, optional
+            The axes object to plot on. If not provided, the current axes will be used.
+        """
+        ax = plt.gca() if ax is None else ax
         ax.plot(
             self.running_times[self.window :],
             self.criterions[self.window :],
@@ -83,37 +89,61 @@ class _PlotArgs:
         ax.legend()
 
 
-def _init_covariance(counts, covariates, coef):
-    """Initialization for covariance for the PLN model. Take the log of counts
+def _init_covariance(
+    counts: torch.Tensor, covariates: torch.Tensor, coef: torch.Tensor
+) -> torch.Tensor:
+    """
+    Initialization for covariance for the PLN model. Take the log of counts
     (careful when counts=0), remove the covariates effects X@coef and
-    then do as a MLE for Gaussians samples.
+    then compute the MLE as for Gaussian samples.
-    Args :
-            counts: torch.tensor. Samples with size (n,p)
-            0: torch.tensor. Offset, size (n,p)
-            covariates: torch.tensor. Covariates, size (n,d)
-            coef: torch.tensor of size (d,p)
-    Returns : torch.tensor of size (p,p).
+
+    Parameters
+    ----------
+    counts : torch.Tensor
+        Samples with size (n,p)
+    covariates : torch.Tensor
+        Covariates, size (n,d)
+    coef : torch.Tensor
+        Coefficient of size (d,p)
+
+    Returns
+    -------
+    torch.Tensor
+        Covariance matrix of size (p,p)
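+
+    Examples
+    --------
+    A shape-check sketch (assumed usage; covariates and coef are unused
+    by the current implementation):
+
+    >>> import torch
+    >>> _init_covariance(torch.ones(5, 3), None, None).shape
+    torch.Size([3, 3])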
     """
     log_y = torch.log(counts + (counts == 0) * math.exp(-2))
     log_y_centered = log_y - torch.mean(log_y, axis=0)
-    # MLE in a Gaussian setting
     n_samples = counts.shape[0]
     sigma_hat = 1 / (n_samples - 1) * (log_y_centered.T) @ log_y_centered
     return sigma_hat
 
 
-def _init_components(counts, covariates, coef, rank):
-    """Inititalization for components for the PLN model. Get a first
-    guess for covariance that is easier to estimate and then takes
-    the rank largest eigenvectors to get components.
-    Args :
-        counts: torch.tensor. Samples with size (n,p)
-        0: torch.tensor. Offset, size (n,p)
-        covarites: torch.tensor. Covariates, size (n,d)
-        coef: torch.tensor of size (d,p)
-        rank: int. The dimension of the latent space, i.e. the reducted dimension.
-    Returns :
-        torch.tensor of size (p,rank). The initialization of components.
+def _init_components(
+    counts: torch.Tensor, covariates: torch.Tensor, coef: torch.Tensor, rank: int
+) -> torch.Tensor:
+    """
+    Initialization for components for the PLN model. Get a first guess for covariance
+    that is easier to estimate and then takes the rank largest eigenvectors to get components.
+
+    Parameters
+    ----------
+    counts : torch.Tensor
+        Samples with size (n,p)
+    covariates : torch.Tensor
+        Covariates, size (n,d)
+    coef : torch.Tensor
+        Coefficient of size (d,p)
+    rank : int
+        The dimension of the latent space, i.e. the reduced dimension.
+
+    Returns
+    -------
+    torch.Tensor
+        Initialization of components of size (p,rank)
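+
+    Examples
+    --------
+    A shape-check sketch (assumed usage):
+
+    >>> import torch
+    >>> _init_components(torch.ones(5, 3), None, None, 2).shape
+    torch.Size([3, 2])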
     """
     sigma_hat = _init_covariance(counts, covariates, coef).detach()
     components = _components_from_covariance(sigma_hat, rank)
@@ -121,23 +151,42 @@ def _init_components(counts, covariates, coef, rank):
 
 
 def _init_latent_mean(
-    counts, covariates, offsets, coef, components, n_iter_max=500, lr=0.01, eps=7e-3
-):
-    """Initialization for the variational parameter M. Basically,
-    the mode of the log_posterior is computed.
-
-    Args:
-        counts: torch.tensor. Samples with size (n,p)
-        0: torch.tensor. Offset, size (n,p)
-        covariates: torch.tensor. Covariates, size (n,d)
-        coef: torch.tensor of size (d,p)
-        N_iter_max: int. The maximum number of iteration in
-            the gradient ascent.
-        lr: positive float. The learning rate of the optimizer.
-        eps: positive float, optional. The tolerance. The algorithm will stop
-            if the maximum of |W_t-W_{t-1}| is lower than eps, where W_t
-            is the t-th iteration of the algorithm.This parameter
-            changes a lot the resulting time of the algorithm. Default is 9e-3.
+    counts: torch.Tensor,
+    covariates: torch.Tensor,
+    offsets: torch.Tensor,
+    coef: torch.Tensor,
+    components: torch.Tensor,
+    n_iter_max: int = 500,
+    lr: float = 0.01,
+    eps: float = 7e-3,
+) -> torch.Tensor:
+    """
+    Initialize the variational parameter M by computing the mode of the log posterior.
+
+    Parameters
+    ----------
+    counts : torch.Tensor
+        Samples with size (n,p)
+    offsets : torch.Tensor
+        Offset, size (n,p)
+    covariates : torch.Tensor
+        Covariates, size (n,d)
+    coef : torch.Tensor
+        Coefficient of size (d,p)
+    components : torch.Tensor
+        Components of size (p,rank)
+    n_iter_max : int, optional
+        The maximum number of iterations in the gradient ascent. Default is 500.
+    lr : float, optional
+        The learning rate of the optimizer. Default is 0.01.
+    eps : float, optional
+        The tolerance. The algorithm will stop if the maximum of |W_t-W_{t-1}| is lower than eps,
+        where W_t is the t-th iteration of the algorithm. Default is 7e-3.
+
+    Returns
+    -------
+    torch.Tensor
+        The initialized latent mean with size (n,rank)
     """
     mode = torch.randn(counts.shape[0], components.shape[1], device=DEVICE)
     mode.requires_grad_(True)
@@ -160,69 +209,99 @@ def _init_latent_mean(
     return mode
 
 
-def _sigmoid(tens):
-    """Compute the _sigmoid function of x element-wise."""
+def _sigmoid(tens: torch.Tensor) -> torch.Tensor:
+    """
+    Compute the sigmoid function of x element-wise.
+
+    Parameters
+    ----------
+    tens : torch.Tensor
+        Input tensor
+
+    Returns
+    -------
+    torch.Tensor
+        Output tensor with sigmoid applied element-wise
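+
+    Examples
+    --------
+    A minimal sketch: the sigmoid of 0 is 0.5:
+
+    >>> import torch
+    >>> float(_sigmoid(torch.tensor(0.0)))
+    0.5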
+    """
     return 1 / (1 + torch.exp(-tens))
 
 
-def sample_pln(components, coef, covariates, offsets, _coef_inflation=None, seed=None):
-    """Sample Poisson log Normal variables. If _coef_inflation is not None, the model will
-    be zero inflated.
-
-    Args:
-        components: torch.tensor of size (p,rank). The matrix components of the PLN model
-        coef: torch.tensor of size (d,p). Regression parameter.
-        0: torch.tensor of size (n,p). Offsets.
-        covariates : torch.tensor of size (n,d). Covariates.
-        _coef_inflation: torch.tensor of size (d,p), optional. If _coef_inflation is not None,
-             the ZIPLN model is chosen, so that it will add a
-             Bernouilli layer. Default is None.
-    Returns :
-        counts: torch.tensor of size (n,p), the count variables.
-        Z: torch.tensor of size (n,p), the gaussian latent variables.
-        ksi: torch.tensor of size (n,p), the bernoulli latent variables
-        (full of zeros if _coef_inflation is None).
+def sample_pln(
+    components: torch.Tensor,
+    coef: torch.Tensor,
+    covariates: torch.Tensor,
+    offsets: torch.Tensor,
+    _coef_inflation: Optional[torch.Tensor] = None,
+    seed: Optional[int] = None,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    """
+    Sample from the Poisson Log-Normal (PLN) model.
+
+    Parameters
+    ----------
+    components : torch.Tensor
+        Components of size (p, rank)
+    coef : torch.Tensor
+        Coefficient of size (d, p)
+    covariates : torch.Tensor or None
+        Covariates, size (n, d) or None
+    offsets : torch.Tensor
+        Offset, size (n, p)
+    _coef_inflation : torch.Tensor or None, optional
+        Coefficient for zero-inflation model, size (d, p) or None. Default is None.
+    seed : int or None, optional
+        Random seed for reproducibility. Default is None.
+
+    Returns
+    -------
+    tuple[torch.Tensor, torch.Tensor, torch.Tensor]
+        Tuple containing counts (torch.Tensor), gaussian (torch.Tensor), and ksi (torch.Tensor)
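+
+    Examples
+    --------
+    A shape-check sketch (assumed usage, CPU execution, no covariates):
+
+    >>> import torch
+    >>> components = torch.randn(5, 2)
+    >>> counts, gaussian, ksi = sample_pln(components, None, None, torch.zeros(10, 5), seed=0)
+    >>> counts.shape
+    torch.Size([10, 5])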
     """
     prev_state = torch.random.get_rng_state()
     if seed is not None:
         torch.random.manual_seed(seed)
+
     n_samples = offsets.shape[0]
     rank = components.shape[1]
+
     if covariates is None:
         XB = 0
     else:
-        XB = covariates @ coef
+        XB = torch.matmul(covariates, coef)
+
     gaussian = torch.mm(torch.randn(n_samples, rank, device=DEVICE), components.T) + XB
     parameter = torch.exp(offsets + gaussian)
+
     if _coef_inflation is not None:
         print("ZIPLN is sampled")
-        zero_inflated_mean = covariates @ _coef_inflation
+        zero_inflated_mean = torch.matmul(covariates, _coef_inflation)
         ksi = torch.bernoulli(1 / (1 + torch.exp(-zero_inflated_mean)))
     else:
         ksi = 0
+
     counts = (1 - ksi) * torch.poisson(parameter)
+
     torch.random.set_rng_state(prev_state)
     return counts, gaussian, ksi
 
 
-# def logit(tens):
-#     """logit function. If x is too close from 1, we set the result to 0.
-#     performs logit element wise."""
-#     return torch.nan_to_num(torch.log(x / (1 - tens)),
-# nan=0, neginf=0, posinf=0)
-
-
-def _components_from_covariance(covariance, rank):
-    """Get the best matrix of size (p,rank) when covariance is of
-    size (p,p). i.e. reduces norm(covariance-components@components.T)
-    Args :
-        covariance: torch.tensor of size (p,p). Should be positive definite and
-            symmetric.
-        rank: int. The number of columns wanted for components
-
-    Returns:
-        components_reduct: torch.tensor of size (p,rank) containing the rank eigenvectors with
-        largest eigenvalues.
+def _components_from_covariance(covariance: torch.Tensor, rank: int) -> torch.Tensor:
+    """
+    Get the best matrix of size (p, rank) when covariance is of size (p, p),
+    i.e., reduce norm(covariance - components @ components.T).
+
+    Parameters
+    ----------
+    covariance : torch.Tensor
+        Covariance matrix of size (p, p)
+    rank : int
+        The number of columns wanted for components
+
+    Returns
+    -------
+    torch.Tensor
+        Requested components of size (p, rank) containing the rank eigenvectors
+        with largest eigenvalues.
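+
+    Examples
+    --------
+    A shape-check sketch (assumed usage):
+
+    >>> import torch
+    >>> _components_from_covariance(torch.eye(4), 2).shape
+    torch.Size([4, 2])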
     """
     eigenvalues, eigenvectors = TLA.eigh(covariance)
     requested_components = eigenvectors[:, -rank:] @ torch.diag(
@@ -231,49 +310,97 @@ def _components_from_covariance(covariance, rank):
     return requested_components
 
 
-def _init_coef(counts, covariates, offsets):
+def _init_coef(
+    counts: torch.Tensor, covariates: torch.Tensor, offsets: torch.Tensor
+) -> torch.Tensor:
+    """
+    Initialize the coefficient for the Poisson regression model.
+
+    Parameters
+    ----------
+    counts : torch.Tensor
+        Samples with size (n, p)
+    covariates : torch.Tensor
+        Covariates, size (n, d)
+    offsets : torch.Tensor
+        Offset, size (n, p)
+
+    Returns
+    -------
+    torch.Tensor or None
+        Coefficient of size (d, p) or None if covariates is None.
+    """
     if covariates is None:
         return None
+
     poiss_reg = _PoissonReg()
     poiss_reg.fit(counts, covariates, offsets)
     return poiss_reg.beta
 
 
-def _log_stirling(integer):
-    """Compute log(n!) even for n large. We use the Stirling formula to avoid
-    numerical infinite values of n!.
-    Args:
-         n: torch.tensor of any size.
-    Returns:
-        An approximation of log(n_!) element-wise.
+def _log_stirling(integer: torch.Tensor) -> torch.Tensor:
     """
-    integer_ = integer + (
-        integer == 0
-    )  # Replace the 0 with 1. It doesn't change anything since 0! = 1!
+    Compute log(n!) even for large n using the Stirling formula to avoid numerical
+    infinite values of n!.
+
+    Parameters
+    ----------
+    integer : torch.Tensor
+        Input tensor
+
+    Returns
+    -------
+    torch.Tensor
+        Approximation of log(n!) element-wise.
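+
+    Examples
+    --------
+    A minimal sketch: log(5!) = log(120) is approximately 4.787, and the
+    Stirling approximation gives a value close to it:
+
+    >>> import torch
+    >>> round(float(_log_stirling(torch.tensor(5.0))), 2)
+    4.77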
+    """
+    integer_ = integer + (integer == 0)  # Replace 0 with 1 since 0! = 1!
     return torch.log(torch.sqrt(2 * np.pi * integer_)) + integer_ * torch.log(
         integer_ / math.exp(1)
     )
 
 
-def log_posterior(counts, covariates, offsets, posterior_mean, components, coef):
-    """Compute the log posterior of the PLN model. Compute it either
-    for posterior_mean of size (N_samples, N_batch,rank) or (batch_size, rank). Need to have
-    both cases since it is done for both cases after. Please the mathematical
-    description of the package for the formula.
-    Args :
-        counts : torch.tensor of size (batch_size, p)
-        covariates : torch.tensor of size (batch_size, d) or (d)
-    Returns: torch.tensor of size (N_samples, batch_size) or (batch_size).
+def log_posterior(
+    counts: torch.Tensor,
+    covariates: torch.Tensor,
+    offsets: torch.Tensor,
+    posterior_mean: torch.Tensor,
+    components: torch.Tensor,
+    coef: torch.Tensor,
+) -> torch.Tensor:
+    """
+    Compute the log posterior of the Poisson Log-Normal (PLN) model.
+
+    Parameters
+    ----------
+    counts : torch.Tensor
+        Samples with size (batch_size, p)
+    covariates : torch.Tensor or None
+        Covariates, size (batch_size, d) or (d)
+    offsets : torch.Tensor
+        Offset, size (batch_size, p)
+    posterior_mean : torch.Tensor
+        Posterior mean with size (N_samples, N_batch, rank) or (batch_size, rank)
+    components : torch.Tensor
+        Components with size (p, rank)
+    coef : torch.Tensor
+        Coefficient with size (d, p)
+
+    Returns
+    -------
+    torch.Tensor
+        Log posterior of size (N_samples, batch_size) or (batch_size).
     """
     length = len(posterior_mean.shape)
     rank = posterior_mean.shape[-1]
     components_posterior_mean = torch.matmul(
         components.unsqueeze(0), posterior_mean.unsqueeze(2)
     ).squeeze()
+
     if covariates is None:
         XB = 0
     else:
-        XB = covariates @ coef
+        XB = torch.matmul(covariates, coef)
+
     log_lambda = offsets + components_posterior_mean + XB
     first_term = (
         -rank / 2 * math.log(2 * math.pi)
@@ -285,19 +412,72 @@ def log_posterior(counts, covariates, offsets, posterior_mean, components, coef)
     return first_term + second_term
 
 
-def _trunc_log(tens, eps=1e-16):
+def _trunc_log(tens: torch.Tensor, eps: float = 1e-16) -> torch.Tensor:
+    """
+    Compute the truncated logarithm of the input tensor.
+
+    Parameters
+    ----------
+    tens : torch.Tensor
+        Input tensor
+    eps : float, optional
+        Truncation value, default is 1e-16
+
+    Returns
+    -------
+    torch.Tensor
+        Truncated logarithm of the input tensor.
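+
+    Examples
+    --------
+    A minimal sketch: values inside (eps, 1 - eps) are left unchanged
+    before taking the log:
+
+    >>> import torch
+    >>> round(float(_trunc_log(torch.tensor([0.5]))), 4)
+    -0.6931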
+    """
     integer = torch.min(torch.max(tens, torch.tensor([eps])), torch.tensor([1 - eps]))
     return torch.log(integer)
 
 
-def _get_offsets_from_sum_of_counts(counts):
+def _get_offsets_from_sum_of_counts(counts: torch.Tensor) -> torch.Tensor:
+    """
+    Compute offsets from the sum of counts.
+
+    Parameters
+    ----------
+    counts : torch.Tensor
+        Samples with size (n, p)
+
+    Returns
+    -------
+    torch.Tensor
+        Offsets of size (n, p)
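+
+    Examples
+    --------
+    A minimal sketch: every entry of a row equals that row's total count:
+
+    >>> import torch
+    >>> offsets = _get_offsets_from_sum_of_counts(torch.ones(3, 4))
+    >>> offsets.shape
+    torch.Size([3, 4])
+    >>> float(offsets[0, 0])
+    4.0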
+    """
     sum_of_counts = torch.sum(counts, axis=1)
     return sum_of_counts.repeat((counts.shape[1], 1)).T
 
 
 def _raise_wrong_dimension_error(
-    str_first_array, str_second_array, dim_first_array, dim_second_array, dim_of_error
-):
+    str_first_array: str,
+    str_second_array: str,
+    dim_first_array: int,
+    dim_second_array: int,
+    dim_of_error: int,
+) -> None:
+    """
+    Raise an error for mismatched dimensions between two tensors.
+
+    Parameters
+    ----------
+    str_first_array : str
+        Name of the first tensor
+    str_second_array : str
+        Name of the second tensor
+    dim_first_array : int
+        Dimension of the first tensor
+    dim_second_array : int
+        Dimension of the second tensor
+    dim_of_error : int
+        Dimension causing the error
+
+    Raises
+    ------
+    ValueError
+        Always raised, with a message describing the dimension mismatch.
+    """
     msg = (
         f"The size of tensor {str_first_array} ({dim_first_array}) must match "
         f"the size of tensor {str_second_array} ({dim_second_array}) at "
@@ -307,8 +487,33 @@ def _raise_wrong_dimension_error(
 
 
 def _check_two_dimensions_are_equal(
-    str_first_array, str_second_array, dim_first_array, dim_second_array, dim_of_error
-):
+    str_first_array: str,
+    str_second_array: str,
+    dim_first_array: int,
+    dim_second_array: int,
+    dim_of_error: int,
+) -> None:
+    """
+    Check if two dimensions are equal.
+
+    Parameters
+    ----------
+    str_first_array : str
+        Name of the first array.
+    str_second_array : str
+        Name of the second array.
+    dim_first_array : int
+        Dimension of the first array.
+    dim_second_array : int
+        Dimension of the second array.
+    dim_of_error : int
+        The dimension index being compared, used in the error message.
+
+    Raises
+    ------
+    ValueError
+        If the dimensions of the two arrays are not equal.
+    """
     if dim_first_array != dim_second_array:
         _raise_wrong_dimension_error(
             str_first_array,
@@ -319,19 +524,62 @@ def _check_two_dimensions_are_equal(
         )
 
 
-def _init_S(counts, covariates, offsets, beta, C, M):
+def _init_S(
+    counts: torch.Tensor,
+    covariates: torch.Tensor,
+    offsets: torch.Tensor,
+    beta: torch.Tensor,
+    C: torch.Tensor,
+    M: torch.Tensor,
+) -> torch.Tensor:
+    """
+    Initialize the S matrix from the Hessian of the log posterior.
+
+    Parameters
+    ----------
+    counts : torch.Tensor, shape (n, p)
+        Count data (unused here; kept for a consistent signature).
+    covariates : torch.Tensor, shape (n, d)
+        Covariate data.
+    offsets : torch.Tensor, shape (n, p)
+        Offset data.
+    beta : torch.Tensor, shape (d, p)
+        Regression coefficients.
+    C : torch.Tensor, shape (p, r)
+        Components.
+    M : torch.Tensor, shape (n, r)
+        Latent means.
+
+    Returns
+    -------
+    torch.Tensor, shape (n, r)
+        Diagonal of minus the inverse Hessian, used to initialize S.
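+
+    Examples
+    --------
+    A shape-check sketch (assumed usage, CPU execution):
+
+    >>> import torch
+    >>> n, p, d, r = 4, 3, 2, 2
+    >>> S = _init_S(torch.ones(n, p), torch.ones(n, d), torch.zeros(n, p),
+    ...             torch.zeros(d, p), torch.randn(p, r), torch.zeros(n, r))
+    >>> S.shape
+    torch.Size([4, 2])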
+    """
     n, rank = M.shape
-    batch_matrix = torch.matmul(C.unsqueeze(2), C.unsqueeze(1)).unsqueeze(0)
-    CW = torch.matmul(C.unsqueeze(0), M.unsqueeze(2)).squeeze()
-    common = torch.exp(offsets + covariates @ beta + CW).unsqueeze(2).unsqueeze(3)
+    batch_matrix = torch.matmul(C[:, :, None], C[:, None, :])[None]
+    CW = torch.matmul(C[None], M[:, :, None]).squeeze()
+    common = torch.exp(offsets + covariates @ beta + CW)[:, :, None, None]
     prod = batch_matrix * common
-    hess_posterior = torch.sum(prod, axis=1) + torch.eye(rank).to(DEVICE)
+    hess_posterior = torch.sum(prod, dim=1) + torch.eye(rank, device=DEVICE)
     inv_hess_posterior = -torch.inverse(hess_posterior)
     hess_posterior = torch.diagonal(inv_hess_posterior, dim1=-2, dim2=-1)
     return hess_posterior
 
 
-def _format_data(data):
+def _format_data(
+    data: Union[pd.DataFrame, np.ndarray, torch.Tensor, None]
+) -> Optional[torch.Tensor]:
+    """
+    Format the input data.
+
+    Parameters
+    ----------
+    data : pd.DataFrame, np.ndarray, or torch.Tensor
+        Input data.
+
+    Returns
+    -------
+    torch.Tensor or None
+        Formatted data.
+    """
     if data is None:
         return None
     if isinstance(data, pd.DataFrame):
@@ -345,8 +593,42 @@ def _format_data(data):
     )
 
 
-def _format_model_param(counts, covariates, offsets, offsets_formula, take_log_offsets):
+def _format_model_param(
+    counts: torch.Tensor,
+    covariates: torch.Tensor,
+    offsets: torch.Tensor,
+    offsets_formula: str,
+    take_log_offsets: bool,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    """
+    Format the model parameters.
+
+    Parameters
+    ----------
+    counts : torch.Tensor, shape (n, p)
+        Count data.
+    covariates : torch.Tensor or None, shape (n, d)
+        Covariate data.
+    offsets : torch.Tensor or None, shape (n, p)
+        Offset data.
+    offsets_formula : str
+        Formula for calculating offsets.
+    take_log_offsets : bool
+        Flag indicating whether to take the logarithm of offsets.
+
+    Returns
+    -------
+    Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
+        Formatted model parameters.
+
+    Raises
+    ------
+    ValueError
+        If counts has negative values.
+    """
     counts = _format_data(counts)
+    if torch.min(counts) < 0:
+        raise ValueError("Counts should be only non negavtive values.")
     if covariates is not None:
         covariates = _format_data(covariates)
     if offsets is None:
@@ -364,7 +646,20 @@ def _format_model_param(counts, covariates, offsets, offsets_formula, take_log_o
     return counts, covariates, offsets
 
 
-def _remove_useless_intercepts(covariates):
+def _remove_useless_intercepts(covariates: torch.Tensor) -> torch.Tensor:
+    """
+    Remove useless intercepts from covariates.
+
+    Parameters
+    ----------
+    covariates : torch.Tensor, shape (n, d)
+        Covariate data.
+
+    Returns
+    -------
+    torch.Tensor
+        Covariate data with useless intercepts removed.
+    """
     covariates = _format_data(covariates)
     if covariates.shape[1] < 2:
         return covariates
@@ -377,7 +672,26 @@ def _remove_useless_intercepts(covariates):
     return covariates
 
 
-def _check_data_shape(counts, covariates, offsets):
+def _check_data_shape(
+    counts: torch.Tensor, covariates: torch.Tensor, offsets: torch.Tensor
+) -> None:
+    """
+    Check the shape of the input data.
+
+    Parameters
+    ----------
+    counts : torch.Tensor, shape (n, p)
+        Count data.
+    covariates : torch.Tensor or None, shape (n, d) or None
+        Covariate data.
+    offsets : torch.Tensor or None, shape (n, p) or None
+        Offset data.
+
+    Raises
+    ------
+    ValueError
+        If the dimensions of the input data do not match.
+    """
     n_counts, p_counts = counts.shape
     n_offsets, p_offsets = offsets.shape
     _check_two_dimensions_are_equal("counts", "offsets", n_counts, n_offsets, 0)
@@ -387,7 +701,20 @@ def _check_data_shape(counts, covariates, offsets):
     _check_two_dimensions_are_equal("counts", "offsets", p_counts, p_offsets, 1)
 
 
-def _nice_string_of_dict(dictionnary):
+def _nice_string_of_dict(dictionnary: dict) -> str:
+    """
+    Create a nicely formatted string representation of a dictionary.
+
+    Parameters
+    ----------
+    dictionnary : dict
+        Dictionary to format.
+
+    Returns
+    -------
+    str
+        Nicely formatted string representation of the dictionary.
+    """
     return_string = ""
     for each_row in zip(*([i] + [j] for i, j in dictionnary.items())):
         for element in list(each_row):
@@ -396,7 +723,26 @@ def _nice_string_of_dict(dictionnary):
     return return_string
 
 
-def _plot_ellipse(mean_x, mean_y, cov, ax):
+def _plot_ellipse(mean_x: float, mean_y: float, cov: np.ndarray, ax) -> float:
+    """
+    Plot an ellipse on the given axes.
+
+    Parameters
+    ----------
+    mean_x : float
+        Mean value of the x-coordinate.
+    mean_y : float
+        Mean value of the y-coordinate.
+    cov : np.ndarray
+        Covariance matrix.
+    ax : matplotlib.axes.Axes
+        Axes object to plot the ellipse on.
+
+    Returns
+    -------
+    float
+        Pearson correlation coefficient.
+    """
     pearson = cov[0, 1] / np.sqrt(cov[0, 0] * cov[1, 1])
     ell_radius_x = np.sqrt(1 + pearson)
     ell_radius_y = np.sqrt(1 - pearson)
@@ -421,7 +767,22 @@ def _plot_ellipse(mean_x, mean_y, cov, ax):
     return pearson
 
 
-def _get_components_simulation(dim, rank):
+def _get_components_simulation(dim: int, rank: int) -> torch.Tensor:
+    """
+    Get the components for simulation.
+
+    Parameters:
+    -----------
+    dim : int
+        Dimension.
+    rank : int
+        Rank.
+
+    Returns:
+    --------
+    torch.Tensor
+        Components for simulation.
+    """
     block_size = dim // rank
     prev_state = torch.random.get_rng_state()
     torch.random.manual_seed(0)
@@ -435,7 +796,26 @@ def _get_components_simulation(dim, rank):
     return components.to(DEVICE)
 
 
-def get_simulation_offsets_cov_coef(n_samples, nb_cov, dim):
+def get_simulation_offsets_cov_coef(
+    n_samples: int, nb_cov: int, dim: int
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    """
+    Get simulation offsets, covariance coefficients.
+
+    Parameters:
+    -----------
+    n_samples : int
+        Number of samples.
+    nb_cov : int
+        Number of covariates.
+    dim : int
+        Dimension.
+
+    Returns:
+    --------
+    tuple[torch.Tensor, torch.Tensor, torch.Tensor]
+        Tuple containing offsets, covariates, and coefficients.
+    """
     prev_state = torch.random.get_rng_state()
     torch.random.manual_seed(0)
     if nb_cov == 0:
@@ -457,8 +837,36 @@ def get_simulation_offsets_cov_coef(n_samples, nb_cov, dim):
 
 
 def get_simulated_count_data(
-    n_samples=100, dim=25, rank=5, nb_cov=1, return_true_param=False, seed=0
-):
+    n_samples: int = 100,
+    dim: int = 25,
+    rank: int = 5,
+    nb_cov: int = 1,
+    return_true_param: bool = False,
+    seed: int = 0,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    """
+    Get simulated count data.
+
+    Parameters
+    ----------
+    n_samples : int, optional
+        Number of samples, by default 100.
+    dim : int, optional
+        Dimension, by default 25.
+    rank : int, optional
+        Rank, by default 5.
+    nb_cov : int, optional
+        Number of covariates, by default 1.
+    return_true_param : bool, optional
+        Whether to also return the true simulation parameters, by default False.
+    seed : int, optional
+        Seed value for random number generation, by default 0.
+
+    Returns
+    -------
+    tuple[torch.Tensor, torch.Tensor, torch.Tensor]
+        Tuple containing counts, covariates, and offsets.
+    """
     components = _get_components_simulation(dim, rank)
     offsets, cov, true_coef = get_simulation_offsets_cov_coef(n_samples, nb_cov, dim)
     true_covariance = torch.matmul(components, components.T)
@@ -468,7 +876,22 @@ def get_simulated_count_data(
     return counts, cov, offsets
 
 
-def get_real_count_data(n_samples=270, dim=100):
+def get_real_count_data(n_samples: int = 270, dim: int = 100) -> np.ndarray:
+    """
+    Get real count data.
+
+    Parameters
+    ----------
+    n_samples : int, optional
+        Number of samples, by default 270.
+    dim : int, optional
+        Dimension, by default 100.
+
+    Returns
+    -------
+    np.ndarray
+        Real count data.
+    """
     if n_samples > 297:
         warnings.warn(
             f"\nTaking the whole 270 samples of the dataset. Requested:n_samples={n_samples}, returned:270"
@@ -486,114 +909,102 @@ def get_real_count_data(n_samples=270, dim=100):
     return counts
 
 
-def _closest(lst, element):
+def _closest(lst: list[float], element: float) -> float:
+    """
+    Find the closest element in a list to a given element.
+
+    Parameters
+    ----------
+    lst : list[float]
+        List of float values.
+    element : float
+        Element to find the closest value to.
+
+    Returns
+    -------
+    float
+        Closest element in the list.
+    """
     lst = np.asarray(lst)
     idx = (np.abs(lst - element)).argmin()
     return lst[idx]
 
 
-class _PoissonReg:
-    """Poisson regressor class."""
-
-    def __init__(self):
-        """No particular initialization is needed."""
-        pass
-
-    def fit(self, Y, covariates, O, Niter_max=300, tol=0.001, lr=0.005, verbose=False):
-        """Run a gradient ascent to maximize the log likelihood, using
-        pytorch autodifferentiation. The log likelihood considered is
-        the one from a poisson regression model. It is roughly the
-        same as PLN without the latent layer Z.
-
-        Args:
-                        Y: torch.tensor. Counts with size (n,p)
-            0: torch.tensor. Offset, size (n,p)
-            covariates: torch.tensor. Covariates, size (n,d)
-            Niter_max: int, optional. The maximum number of iteration.
-                Default is 300.
-            tol: non negative float, optional. The tolerance criteria.
-                Will stop if the norm of the gradient is less than
-                or equal to this threshold. Default is 0.001.
-            lr: positive float, optional. Learning rate for the gradient ascent.
-                Default is 0.005.
-            verbose: bool, optional. If True, will print some stats.
-
-        Returns : None. Update the parameter beta. You can access it
-                by calling self.beta.
-        """
-        # Initialization of beta of size (d,p)
-        beta = torch.rand(
-            (covariates.shape[1], Y.shape[1]), device=DEVICE, requires_grad=True
-        )
-        optimizer = torch.optim.Rprop([beta], lr=lr)
-        i = 0
-        grad_norm = 2 * tol  # Criterion
-        while i < Niter_max and grad_norm > tol:
-            loss = -_compute_poissreg_log_like(Y, O, covariates, beta)
-            loss.backward()
-            optimizer.step()
-            grad_norm = torch.norm(beta.grad)
-            beta.grad.zero_()
-            i += 1
-            if verbose:
-                if i % 10 == 0:
-                    print("log like : ", -loss)
-                    print("grad_norm : ", grad_norm)
-                if i < Niter_max:
-                    print("Tolerance reached in {} iterations".format(i))
-                else:
-                    print("Maxium number of iterations reached")
-        self.beta = beta
-
-
-def _compute_poissreg_log_like(Y, O, covariates, beta):
-    """Compute the log likelihood of a Poisson regression."""
-    # Matrix multiplication of X and beta.
-    XB = torch.matmul(covariates.unsqueeze(1), beta.unsqueeze(0)).squeeze()
-    # Returns the formula of the log likelihood of a poisson regression model.
-    return torch.sum(-torch.exp(O + XB) + torch.multiply(Y, O + XB))
-
-
-def _to_tensor(obj):
-    if isinstance(obj, np.ndarray):
-        return torch.from_numpy(obj)
-    if isinstance(obj, torch.Tensor):
-        return obj
-    if isinstance(obj, pd.DataFrame):
-        return torch.from_numpy(obj.values)
-    raise TypeError("Please give either a nd.array or torch.Tensor or pd.DataFrame")
-
+def load_model(path_of_directory: str) -> Dict[str, Any]:
+    """
+    Load models from the given directory.
 
-def _check_dimensions_are_equal(tens1, tens2):
-    if tens1.shape[0] != tens2.shape[0] or tens1.shape[1] != tens2.shape[1]:
-        raise ValueError("Tensors should have the same size.")
+    Parameters
+    ----------
+    path_of_directory : str
+        The path to the directory containing the models.
 
+    Returns
+    -------
+    Dict[str, Any]
+        A dictionary containing the loaded models.
 
-def load_model(path_of_directory):
-    working_dict = os.getcwd()
+    """
+    working_dir = os.getcwd()
     os.chdir(path_of_directory)
     all_files = os.listdir()
     data = {}
     for filename in all_files:
-        if len(filename) > 4:
-            if filename[-4:] == ".csv":
-                parameter = filename[:-4]
-                try:
-                    data[parameter] = pd.read_csv(filename, header=None).values
-                except pd.errors.EmptyDataError as err:
-                    print(
-                        f"Can't load {parameter} since empty. Standard initialization will be performed"
-                    )
-    os.chdir(working_dict)
+        if filename.endswith(".csv"):
+            parameter = filename[:-4]
+            try:
+                data[parameter] = pd.read_csv(filename, header=None).values
+            except pd.errors.EmptyDataError:
+                print(
+                    f"Can't load {parameter} since empty. Standard initialization will be performed"
+                )
+    os.chdir(working_dir)
     return data
 
 
-def load_pln(path_of_directory):
+def load_pln(path_of_directory: str) -> Dict[str, Any]:
+    """
+    Load PLN models from the given directory.
+
+    Parameters
+    ----------
+    path_of_directory : str
+        The path to the directory containing the PLN models.
+
+    Returns
+    -------
+    Dict[str, Any]
+        A dictionary containing the loaded PLN models.
+
+    """
     return load_model(path_of_directory)
 
 
-def load_plnpca(path_of_directory, ranks=None):
-    working_dict = os.getcwd()
+def load_plnpca(
+    path_of_directory: str, ranks: Optional[list[int]] = None
+) -> Dict[int, Dict[str, Any]]:
+    """
+    Load PLNPCA models from the given directory.
+
+    Parameters
+    ----------
+    path_of_directory : str
+        The path to the directory containing the PLNPCA models.
+    ranks : list[int], optional
+        A list of ranks specifying which models to load. If None, all models in the directory will be loaded.
+
+    Returns
+    -------
+    Dict[int, Dict[str, Any]]
+        A dictionary containing the loaded PLNPCA models, with ranks as keys.
+
+    Raises
+    ------
+    ValueError
+        If an invalid model name is encountered and the rank cannot be determined.
+
+    """
+    working_dir = os.getcwd()
     os.chdir(path_of_directory)
     if ranks is None:
         dirnames = os.listdir()
@@ -609,20 +1020,51 @@ def load_plnpca(path_of_directory, ranks=None):
     datas = {}
     for rank in ranks:
         datas[rank] = load_model(f"_PLNPCA_rank_{rank}")
-    os.chdir(working_dict)
+    os.chdir(working_dir)
     return datas
 
 
-def _check_right_rank(data, rank):
+def _check_right_rank(data: Dict[str, Any], rank: int) -> None:
+    """
+    Check if the rank of the given data matches the specified rank.
+
+    Parameters
+    ----------
+    data : Dict[str, Any]
+        A dictionary containing the data.
+    rank : int
+        The expected rank.
+
+    Raises
+    ------
+    RuntimeError
+        If the rank of the data does not match the specified rank.
+
+    """
     data_rank = data["latent_mean"].shape[1]
     if data_rank != rank:
         raise RuntimeError(
-            f"Wrong rank during initialization."
-            f" Got rank {rank} and data with rank {data_rank}."
+            f"Wrong rank during initialization. Got rank {rank} and data with rank {data_rank}."
         )
 
 
-def _extract_data_from_formula(formula, data):
+def _extract_data_from_formula(formula: str, data: Dict[str, Any]) -> tuple:
+    """
+    Extract data from the given formula and data dictionary.
+
+    Parameters
+    ----------
+    formula : str
+        The formula specifying the data to extract.
+    data : Dict[str, Any]
+        A dictionary containing the data.
+
+    Returns
+    -------
+    tuple
+        A tuple containing the extracted counts, covariates, and offsets.
+
+    """
     dmatrix = dmatrices(formula, data=data)
     counts = dmatrix[0]
     covariates = dmatrix[1]
@@ -632,13 +1074,185 @@ def _extract_data_from_formula(formula, data):
     return counts, covariates, offsets
 
 
-def _is_dict_of_dict(dictionnary):
-    if isinstance(dictionnary[list(dictionnary.keys())[0]], dict):
-        return True
-    return False
+def _is_dict_of_dict(dictionary: Dict[Any, Any]) -> bool:
+    """
+    Check if the given dictionary is a dictionary of dictionaries.
+
+    Parameters
+    ----------
+    dictionary : Dict[Any, Any]
+        The dictionary to check.
+
+    Returns
+    -------
+    bool
+        True if the dictionary is a dictionary of dictionaries, False otherwise.
 
+    """
+    return isinstance(next(iter(dictionary.values())), dict)
+
+
+def _get_dict_initialization(
+    rank: int, dict_of_dict: Optional[Dict[int, Dict[str, Any]]]
+) -> Optional[Dict[str, Any]]:
+    """
+    Get the initialization dictionary for the given rank.
 
-def _get_dict_initalization(rank, dict_of_dict):
+    Parameters
+    ----------
+    rank : int
+        The rank to get the initialization dictionary for.
+    dict_of_dict : Dict[int, Dict[str, Any]], optional
+        A dictionary containing initialization dictionaries for different ranks.
+
+    Returns
+    -------
+    Optional[Dict[str, Any]]
+        The initialization dictionary for the given rank, or None if it does not exist.
+
+    """
     if dict_of_dict is None:
         return None
-    return dict_of_dict[rank]
+    return dict_of_dict.get(rank)
+
+
+def _to_tensor(
+    obj: Union[np.ndarray, torch.Tensor, pd.DataFrame, None]
+) -> Union[torch.Tensor, None]:
+    """
+    Convert an object to a PyTorch tensor.
+
+    Parameters
+    ----------
+    obj : np.ndarray or torch.Tensor or pd.DataFrame or None
+        The object to be converted.
+
+    Returns
+    -------
+    torch.Tensor or None
+        The converted PyTorch tensor.
+
+    Raises
+    ------
+    TypeError
+        If the input object is not an np.ndarray, torch.Tensor, pd.DataFrame, or None.
+    if obj is None:
+        return None
+    if isinstance(obj, np.ndarray):
+        return torch.from_numpy(obj)
+    if isinstance(obj, torch.Tensor):
+        return obj
+    if isinstance(obj, pd.DataFrame):
+        return torch.from_numpy(obj.values)
+    raise TypeError(
+        "Please give either an np.ndarray or torch.Tensor or pd.DataFrame or None"
+    )
+
+
+class _PoissonReg:
+    """
+    Poisson regression model.
+
+    Attributes
+    ----------
+    beta : torch.Tensor
+        The learned regression coefficients.
+
+    Methods
+    -------
+    fit(Y, covariates, O, Niter_max=300, tol=0.001, lr=0.005, verbose=False)
+        Fit the Poisson regression model to the given data.
+
+    """
+
+    def __init__(self) -> None:
+        self.beta: Optional[torch.Tensor] = None
+
+    def fit(
+        self,
+        Y: torch.Tensor,
+        covariates: torch.Tensor,
+        O: torch.Tensor,
+        Niter_max: int = 300,
+        tol: float = 0.001,
+        lr: float = 0.005,
+        verbose: bool = False,
+    ) -> None:
+        """
+        Fit the Poisson regression model to the given data.
+
+        Parameters
+        ----------
+        Y : torch.Tensor
+            The dependent variable of shape (n_samples, n_features).
+        covariates : torch.Tensor
+            The covariates of shape (n_samples, n_covariates).
+        O : torch.Tensor
+            The offset term of shape (n_samples, n_features).
+        Niter_max : int, optional
+            The maximum number of iterations (default is 300).
+        tol : float, optional
+            The tolerance for convergence (default is 0.001).
+        lr : float, optional
+            The learning rate (default is 0.005).
+        verbose : bool, optional
+            Whether to print intermediate information during fitting (default is False).
+
+        """
+        beta = torch.rand(
+            (covariates.shape[1], Y.shape[1]), device=DEVICE, requires_grad=True
+        )
+        optimizer = torch.optim.Rprop([beta], lr=lr)
+        i = 0
+        grad_norm = 2 * tol  # Criterion
+        while i < Niter_max and grad_norm > tol:
+            loss = -compute_poissreg_log_like(Y, O, covariates, beta)
+            loss.backward()
+            optimizer.step()
+            grad_norm = torch.norm(beta.grad)
+            beta.grad.zero_()
+            i += 1
+            if verbose and i % 10 == 0:
+                print("log like : ", -loss)
+                print("grad_norm : ", grad_norm)
+        if verbose:
+            if i < Niter_max:
+                print(f"Tolerance reached in {i} iterations")
+            else:
+                print("Maximum number of iterations reached")
+        self.beta = beta
+
+
+def compute_poissreg_log_like(
+    Y: torch.Tensor, O: torch.Tensor, covariates: torch.Tensor, beta: torch.Tensor
+) -> torch.Tensor:
+    """
+    Compute the log likelihood of a Poisson regression model.
+
+    Parameters
+    ----------
+    Y : torch.Tensor
+        The dependent variable of shape (n_samples, n_features).
+    O : torch.Tensor
+        The offset term of shape (n_samples, n_features).
+    covariates : torch.Tensor
+        The covariates of shape (n_samples, n_covariates).
+    beta : torch.Tensor
+        The regression coefficients of shape (n_covariates, n_features).
+
+    Returns
+    -------
+    torch.Tensor
+        The log likelihood of the Poisson regression model.
+
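+    Examples
+    --------
+    A minimal sketch: with zero offsets and zero coefficients, the log
+    likelihood is sum(-exp(0) + Y * 0) = -(n_samples * n_features):
+
+    >>> import torch
+    >>> Y, O = torch.ones(2, 3), torch.zeros(2, 3)
+    >>> X, beta = torch.ones(2, 1), torch.zeros(1, 3)
+    >>> float(compute_poissreg_log_like(Y, O, X, beta))
+    -6.0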
+    """
+    XB = torch.matmul(covariates.unsqueeze(1), beta.unsqueeze(0)).squeeze()
+    return torch.sum(-torch.exp(O + XB) + torch.multiply(Y, O + XB))
+
+
+def array2tensor(func):
+    """Decorate a setter so that array-like inputs are first converted to tensors."""
+
+    def setter(self, array_like):
+        array_like = _to_tensor(array_like)
+        func(self, array_like)
+
+    return setter
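+
+
+# A hypothetical usage sketch (illustration only; `_Demo` is not part of the
+# package). `array2tensor` is meant to decorate property setters so callers
+# can pass numpy arrays, pandas DataFrames, or tensors interchangeably:
+#
+#     class _Demo:
+#         @property
+#         def latent_mean(self):
+#             return self._latent_mean
+#
+#         @latent_mean.setter
+#         @array2tensor
+#         def latent_mean(self, latent_mean):
+#             self._latent_mean = latent_mean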
diff --git a/pyPLNmodels/elbos.py b/pyPLNmodels/elbos.py
index 08d7a9818b2272ee89b56c8e565cdde0165b1bf8..c5f8927cd355ab2c9cb9030266cf4dbce49b4382 100644
--- a/pyPLNmodels/elbos.py
+++ b/pyPLNmodels/elbos.py
@@ -3,27 +3,48 @@ from ._utils import _log_stirling, _trunc_log
 from ._closed_forms import _closed_formula_covariance, _closed_formula_coef
 
 
-def elbo_pln(counts, covariates, offsets, latent_mean, latent_var, covariance, coef):
+from typing import Optional
+
+
+def elbo_pln(
+    counts: torch.Tensor,
+    covariates: Optional[torch.Tensor],
+    offsets: torch.Tensor,
+    latent_mean: torch.Tensor,
+    latent_var: torch.Tensor,
+    covariance: torch.Tensor,
+    coef: torch.Tensor,
+) -> torch.Tensor:
     """
-    Compute the ELBO (Evidence LOwer Bound) for the PLN model. See the doc for more details
-    on the computation.
+    Compute the ELBO (Evidence Lower Bound) for the PLN model.
+
+    Parameters
+    ----------
+    counts : torch.Tensor
+        Counts with size (n, p).
+    covariates : torch.Tensor or None
+        Covariates with size (n, d), or None.
+    offsets : torch.Tensor
+        Offset with size (n, p).
+    latent_mean : torch.Tensor
+        Variational parameter with size (n, p).
+    latent_var : torch.Tensor
+        Variational parameter with size (n, p).
+    covariance : torch.Tensor
+        Model parameter with size (p, p).
+    coef : torch.Tensor
+        Model parameter with size (d, p).
 
-    Args:
-        counts: torch.tensor. Counts with size (n,p)
-        0: torch.tensor. Offset, size (n,p)
-        covariates: torch.tensor. Covariates, size (n,d)
-        latent_mean: torch.tensor. Variational parameter with size (n,p)
-        latent_var: torch.tensor. Variational parameter with size (n,p)
-        covariance: torch.tensor. Model parameter with size (p,p)
-        coef: torch.tensor. Model parameter with size (d,p)
-    Returns:
+    Returns
-        torch.tensor of size 1 with a gradient.
+    -------
+    torch.Tensor
+        The ELBO (Evidence Lower Bound) with size 1, with a gradient.
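+
+    Examples
+    --------
+    A shape-check sketch (assumed usage: identity covariance, no covariates):
+
+    >>> import torch
+    >>> n, p = 4, 3
+    >>> elbo = elbo_pln(torch.ones(n, p), None, torch.zeros(n, p),
+    ...                 torch.zeros(n, p), torch.ones(n, p), torch.eye(p), None)
+    >>> elbo.shape
+    torch.Size([])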
     """
     n_samples, dim = counts.shape
     s_rond_s = torch.square(latent_var)
     offsets_plus_m = offsets + latent_mean
     if covariates is None:
-        XB = 0
+        XB = torch.zeros_like(counts)
     else:
         XB = covariates @ coef
     m_minus_xb = latent_mean - XB
@@ -42,55 +63,88 @@ def elbo_pln(counts, covariates, offsets, latent_mean, latent_var, covariance, c
     return elbo / n_samples
 
 
-def profiled_elbo_pln(counts, covariates, offsets, latent_mean, latent_var):
+def profiled_elbo_pln(
+    counts: torch.Tensor,
+    covariates: Optional[torch.Tensor],
+    offsets: torch.Tensor,
+    latent_mean: torch.Tensor,
+    latent_var: torch.Tensor,
+) -> torch.Tensor:
     """
-    Compute the ELBO (Evidence LOwer Bound) for the PLN model. We use the fact that covariance and coef are
-    completely determined by latent_mean,latent_var, and the covariates. See the doc for more details
-    on the computation.
+    Compute the ELBO (Evidence Lower Bound) for the PLN model with profiled parameters.
+
+    Parameters
+    ----------
+    counts : torch.Tensor
+        Counts with size (n, p).
+    covariates : torch.Tensor or None
+        Covariates with size (n, d), or None.
+    offsets : torch.Tensor
+        Offset with size (n, p).
+    latent_mean : torch.Tensor
+        Variational parameter with size (n, p).
+    latent_var : torch.Tensor
+        Variational parameter with size (n, p).
 
-    Args:
-        counts: torch.tensor. Counts with size (n,p)
-        0: torch.tensor. Offset, size (n,p)
-        covariates: torch.tensor. Covariates, size (n,d)
-        latent_mean: torch.tensor. Variational parameter with size (n,p)
-        latent_var: torch.tensor. Variational parameter with size (n,p)
-        covariance: torch.tensor. Model parameter with size (p,p)
-        coef: torch.tensor. Model parameter with size (d,p)
-    Returns:
+    Returns
-        torch.tensor of size 1 with a gradient.
+    -------
+    torch.Tensor
+        The ELBO (Evidence Lower Bound) with size 1, with a gradient.
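+
+    Examples
+    --------
+    A shape-check sketch (assumed usage, no covariates):
+
+    >>> import torch
+    >>> elbo = profiled_elbo_pln(torch.ones(4, 3), None, torch.zeros(4, 3),
+    ...                          torch.randn(4, 3), torch.ones(4, 3))
+    >>> elbo.shape
+    torch.Size([])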
     """
     n_samples, _ = counts.shape
-    s_rond_s = torch.square(latent_var)
-    offsets_plus_m = offsets + latent_mean
+    s_squared = torch.square(latent_var)
+    offsets_plus_mean = offsets + latent_mean
     closed_coef = _closed_formula_coef(covariates, latent_mean)
     closed_covariance = _closed_formula_covariance(
         covariates, latent_mean, latent_var, closed_coef, n_samples
     )
     elbo = -0.5 * n_samples * torch.logdet(closed_covariance)
     elbo += torch.sum(
-        counts * offsets_plus_m
-        - torch.exp(offsets_plus_m + s_rond_s / 2)
-        + 0.5 * torch.log(s_rond_s)
+        counts * offsets_plus_mean
+        - torch.exp(offsets_plus_mean + s_squared / 2)
+        + 0.5 * torch.log(s_squared)
     )
     elbo -= torch.sum(_log_stirling(counts))
     return elbo / n_samples
 
 
-def elbo_plnpca(counts, covariates, offsets, latent_mean, latent_var, components, coef):
+def elbo_plnpca(
+    counts: torch.Tensor,
+    covariates: Optional[torch.Tensor],
+    offsets: torch.Tensor,
+    latent_mean: torch.Tensor,
+    latent_var: torch.Tensor,
+    components: torch.Tensor,
+    coef: torch.Tensor,
+) -> torch.Tensor:
     """
-    Compute the ELBO (Evidence LOwer Bound) for the PLN model with a PCA
-    parametrization. See the doc for more details on the computation.
+    Compute the ELBO (Evidence Lower Bound) for the PLN model with PCA parametrization.
+
+    Parameters
+    ----------
+    counts : torch.Tensor
+        Counts with size (n, p).
+    covariates : torch.Tensor or None
+        Covariates with size (n, d), or None.
+    offsets : torch.Tensor
+        Offset with size (n, p).
+    latent_mean : torch.Tensor
+        Variational parameter with size (n, q).
+    latent_var : torch.Tensor
+        Variational parameter with size (n, q).
+    components : torch.Tensor
+        Model parameter with size (p, q).
+    coef : torch.Tensor
+        Model parameter with size (d, p).
 
-    Args:
-        counts: torch.tensor. Counts with size (n,p)
-        0: torch.tensor. Offset, size (n,p)
-        covariates: torch.tensor. Covariates, size (n,d)
-        latent_mean: torch.tensor. Variational parameter with size (n,p)
-        latent_var: torch.tensor. Variational parameter with size (n,p)
-        components: torch.tensor. Model parameter with size (p,q)
-        coef: torch.tensor. Model parameter with size (d,p)
-    Returns:
+    Returns
-        torch.tensor of size 1 with a gradient.
+    -------
+    torch.Tensor
+        The ELBO (Evidence Lower Bound) with size 1, with a gradient.
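+
+    Examples
+    --------
+    A shape-check sketch (assumed usage, no covariates):
+
+    >>> import torch
+    >>> n, p, q = 4, 3, 2
+    >>> elbo = elbo_plnpca(torch.ones(n, p), None, torch.zeros(n, p), torch.zeros(n, q),
+    ...                    torch.ones(n, q), torch.randn(p, q), None)
+    >>> elbo.shape
+    torch.Size([])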
     """
     n_samples = counts.shape[0]
     rank = components.shape[1]
@@ -99,22 +153,22 @@ def elbo_plnpca(counts, covariates, offsets, latent_mean, latent_var, components
     else:
         XB = covariates @ coef
     log_intensity = offsets + XB + latent_mean @ components.T
-    s_rond_s = torch.square(latent_var)
+    s_squared = torch.square(latent_var)
     counts_log_intensity = torch.sum(counts * log_intensity)
-    minus_intensity_plus_s_rond_s_cct = torch.sum(
-        -torch.exp(log_intensity + 0.5 * s_rond_s @ (components * components).T)
+    minus_intensity_plus_s_squared_cct = torch.sum(
+        -torch.exp(log_intensity + 0.5 * s_squared @ (components * components).T)
     )
-    minuslogs_rond_s = 0.5 * torch.sum(torch.log(s_rond_s))
-    mm_plus_s_rond_s = -0.5 * torch.sum(
+    minus_logs_squared = 0.5 * torch.sum(torch.log(s_squared))
+    mm_plus_s_squared = -0.5 * torch.sum(
         torch.square(latent_mean) + torch.square(latent_var)
     )
-    _log_stirlingcounts = torch.sum(_log_stirling(counts))
+    log_stirling_counts = torch.sum(_log_stirling(counts))
     return (
         counts_log_intensity
-        + minus_intensity_plus_s_rond_s_cct
-        + minuslogs_rond_s
-        + mm_plus_s_rond_s
-        - _log_stirlingcounts
+        + minus_intensity_plus_s_squared_cct
+        + minus_logs_squared
+        + mm_plus_s_squared
+        - log_stirling_counts
         + 0.5 * n_samples * rank
     ) / n_samples
 
diff --git a/pyPLNmodels/models.py b/pyPLNmodels/models.py
index 22f43d92588884c5bf28156dea4a96a4a08c1799..c3849fb9cce9eeda827cfa6a5f02ba759139c9b4 100644
--- a/pyPLNmodels/models.py
+++ b/pyPLNmodels/models.py
@@ -26,19 +26,17 @@ from ._utils import (
     _init_covariance,
     _init_components,
     _init_coef,
-    _check_two_dimensions_are_equal,
     _init_latent_mean,
     _format_data,
     _format_model_param,
-    _check_data_shape,
     _nice_string_of_dict,
     _plot_ellipse,
     _closest,
-    _to_tensor,
-    _check_dimensions_are_equal,
+    _check_data_shape,
     _check_right_rank,
     _extract_data_from_formula,
-    _get_dict_initalization,
+    _get_dict_initialization,
+    array2tensor,
 )
 
 if torch.cuda.is_available():
@@ -73,7 +71,6 @@ class _PLN(ABC):
     _latent_var: torch.Tensor
     _latent_mean: torch.Tensor
 
-    @singledispatchmethod
     def __init__(
         self,
         counts,
@@ -94,11 +91,11 @@ class _PLN(ABC):
         self._fitted = False
         self._plotargs = _PlotArgs(self._WINDOW)
         if dict_initialization is not None:
-            self._set__init_parameters(dict_initialization)
+            self._set_init_parameters(dict_initialization)
 
-    @__init__.register(str)
-    def _(
-        self,
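+    # from_formula replaces the former str-dispatched __init__: the formula
+    # and data dict are parsed by _extract_data_from_formula into counts,
+    # covariates and offsets, which are then handed to the regular
+    # constructor.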
+    @classmethod
+    def from_formula(
+        cls,
         formula: str,
         data: dict,
         offsets_formula="logsum",
@@ -106,7 +103,7 @@ class _PLN(ABC):
         take_log_offsets=False,
     ):
         counts, covariates, offsets = _extract_data_from_formula(formula, data)
-        self.__init__(
+        return cls(
             counts,
             covariates,
             offsets,
@@ -115,7 +112,7 @@ class _PLN(ABC):
             take_log_offsets,
         )
 
-    def _set__init_parameters(self, dict_initialization):
+    def _set_init_parameters(self, dict_initialization):
         if "coef" not in dict_initialization.keys():
             print("No coef is initialized.")
             self.coef = None
@@ -177,7 +174,6 @@ class _PLN(ABC):
             self._random_init_model_parameters()
             self._random_init_latent_parameters()
         print("Initialization finished")
-        self._put_parameters_to_device()
 
     def _put_parameters_to_device(self):
         for parameter in self._list_of_parameters_needing_gradient:
@@ -218,6 +214,7 @@ class _PLN(ABC):
             self._init_parameters(do_smart_init)
         else:
             self._beginning_time -= self._plotargs.running_times[-1]
+        self._put_parameters_to_device()
         self.optim = class_optimizer(self._list_of_parameters_needing_gradient, lr=lr)
         stop_condition = False
         while self.nb_iteration_done < nb_max_iteration and stop_condition == False:
@@ -360,7 +357,7 @@ class _PLN(ABC):
             _, axes = plt.subplots(1, nb_axes, figsize=(23, 5))
         if self._fitted is True:
             self._plotargs._show_loss(ax=axes[2])
-            self._plotargs._show_stopping_criteration(ax=axes[1])
+            self._plotargs._show_stopping_criterion(ax=axes[1])
             self.display_covariance(ax=axes[0])
         else:
             self.display_covariance(ax=axes)
@@ -412,25 +409,35 @@ class _PLN(ABC):
 
     @property
     def coef(self):
-        return self._attribute_or_none("_coef")
+        return self._cpu_attribute_or_none("_coef")
 
     @property
     def latent_mean(self):
-        return self._attribute_or_none("_latent_mean")
+        return self._cpu_attribute_or_none("_latent_mean")
 
     @property
     def latent_var(self):
-        return self._attribute_or_none("_latent_var")
-
-    @latent_var.setter
-    def latent_var(self, latent_var):
-        self._latent_var = latent_var
+        return self._cpu_attribute_or_none("_latent_var")
 
     @latent_mean.setter
+    @array2tensor
     def latent_mean(self, latent_mean):
+        if latent_mean.shape != (self.n_samples, self.dim):
+            raise ValueError(
+                f"Wrong shape. Expected {self.n_samples, self.dim}, got {latent_mean.shape}"
+            )
         self._latent_mean = latent_mean
 
-    def _attribute_or_none(self, attribute_name):
+    @latent_var.setter
+    @array2tensor
+    def latent_var(self, latent_var):
+        if latent_var.shape != (self.n_samples, self.dim):
+            raise ValueError(
+                f"Wrong shape. Expected {self.n_samples, self.dim}, got {latent_var.shape}"
+            )
+        self._latent_var = latent_var
+
+    def _cpu_attribute_or_none(self, attribute_name):
         if hasattr(self, attribute_name):
             attr = getattr(self, attribute_name)
             if isinstance(attr, torch.Tensor):
@@ -454,33 +461,51 @@ class _PLN(ABC):
 
     @property
     def counts(self):
-        return self._attribute_or_none("_counts")
+        return self._cpu_attribute_or_none("_counts")
 
     @property
     def offsets(self):
-        return self._attribute_or_none("_offsets")
+        return self._cpu_attribute_or_none("_offsets")
 
     @property
     def covariates(self):
-        return self._attribute_or_none("_covariates")
+        return self._cpu_attribute_or_none("_covariates")
 
     @counts.setter
+    @array2tensor
     def counts(self, counts):
-        counts = _to_tensor(counts)
-        if hasattr(self, "_counts"):
-            _check_dimensions_are_equal(self._counts, counts)
+        if self.counts.shape != counts.shape:
+            raise ValueError(
+                f"Wrong shape for the counts. Expected {self.counts.shape}, got {counts.shape}"
+            )
+        if torch.min(counts) < 0:
+            raise ValueError("Counts should be non-negative.")
         self._counts = counts
 
     @offsets.setter
+    @array2tensor
     def offsets(self, offsets):
+        if self.offsets.shape != offsets.shape:
+            raise ValueError(
+                f"Wrong shape for the offsets. Expected {self.offsets.shape}, got {offsets.shape}"
+            )
         self._offsets = offsets
 
     @covariates.setter
+    @array2tensor
     def covariates(self, covariates):
+        _check_data_shape(self.counts, covariates, self.offsets)
         self._covariates = covariates
 
     @coef.setter
+    @array2tensor
     def coef(self, coef):
+        if coef is None:
+            pass
+        elif coef.shape != (self.nb_cov, self.dim):
+            raise ValueError(
+                f"Wrong shape for the counts. Expected {(self.nb_cov, self.dim)}, got {coef.shape}"
+            )
         self._coef = coef
 
     @property
@@ -544,8 +569,12 @@ class PLN(_PLN):
 
     @property
     def coef(self):
-        if hasattr(self, "_latent_mean") and hasattr(self, "_covariates"):
-            return self._coef
+        if (
+            hasattr(self, "_latent_mean")
+            and hasattr(self, "_covariates")
+            and self.nb_cov > 0
+        ):
+            return self._coef.detach().cpu()
         return None
 
     @coef.setter
@@ -641,7 +670,6 @@ class PLN(_PLN):
 class PLNPCA:
     _NAME = "PLNPCA"
 
-    @singledispatchmethod
     def __init__(
         self,
         counts,
@@ -664,10 +692,10 @@ class PLNPCA:
         _check_data_shape(self._counts, self._covariates, self._offsets)
         self._fitted = False
 
-    @__init__.register(str)
-    def _(
-        self,
-        formula,
+    @classmethod
+    def from_formula(
+        cls,
+        formula: str,
         data: dict,
         offsets_formula="logsum",
         ranks=range(3, 5),
@@ -675,7 +703,7 @@ class PLNPCA:
         take_log_offsets=False,
     ):
         counts, covariates, offsets = _extract_data_from_formula(formula, data)
-        self.__init__(
+        return cls(
             counts,
             covariates,
             offsets,
@@ -693,38 +721,65 @@ class PLNPCA:
     def counts(self):
         return self.list_models[0].counts
 
+    @property
+    def coef(self):
+        return {model.rank: model.coef for model in self.list_models}
+
+    @property
+    def components(self):
+        return {model.rank: model.components for model in self.list_models}
+
+    @property
+    def latent_mean(self):
+        return {model.rank: model.latent_mean for model in self.list_models}
+
+    @property
+    def latent_var(self):
+        return {model.rank: model.latent_var for model in self.list_models}
+
     @counts.setter
+    @array2tensor
     def counts(self, counts):
-        counts = _format_data(counts)
-        if hasattr(self, "_counts"):
-            _check_dimensions_are_equal(self._counts, counts)
-        self._counts = counts
+        for model in self.list_models:
+            model.counts = counts
+
+    @coef.setter
+    @array2tensor
+    def coef(self, coef):
+        for model in self.list_models:
+            model.coef = coef
 
     @covariates.setter
+    @array2tensor
     def covariates(self, covariates):
-        covariates = _format_data(covariates)
-        # if hasattr(self,)
-        self._covariates = covariates
+        for model in self.list_models:
+            model.covariates = covariates
 
     @property
     def offsets(self):
         return self.list_models[0].offsets
 
+    @offsets.setter
+    @array2tensor
+    def offsets(self, offsets):
+        for model in self.list_models:
+            model.offsets = offsets
+
     def _init_models(self, ranks, dict_of_dict_initialization):
         if isinstance(ranks, (Iterable, np.ndarray)):
             self.list_models = []
             for rank in ranks:
                 if isinstance(rank, (int, np.integer)):
-                    dict_initialization = _get_dict_initalization(
+                    dict_initialization = _get_dict_initialization(
                         rank, dict_of_dict_initialization
                     )
                     self.list_models.append(
                         _PLNPCA(
-                            self._counts,
-                            self._covariates,
-                            self._offsets,
-                            rank,
-                            dict_initialization,
+                            counts=self._counts,
+                            covariates=self._covariates,
+                            offsets=self._offsets,
+                            rank=rank,
+                            dict_initialization=dict_initialization,
                         )
                     )
                 else:
@@ -733,7 +788,7 @@ class PLNPCA:
                         f"of integers or an integer."
                     )
         elif isinstance(ranks, (int, np.integer)):
-            dict_initialization = _get_dict_initalization(
+            dict_initialization = _get_dict_initialization(
                 ranks, dict_of_dict_initialization
             )
             self.list_models = [
@@ -741,7 +796,7 @@ class PLNPCA:
                     self._counts,
                     self._covariates,
                     self._offsets,
-                    rank,
-                    dict_initialization,
+                    rank=ranks,
+                    dict_initialization=dict_initialization,
                 )
             ]
@@ -781,8 +836,9 @@ class PLNPCA:
         verbose=False,
     ):
         self._pring_beginning_message()
-        for pca in self.dict_models.values():
-            pca.fit(
+        for i, model in enumerate(self.list_models):
+            # fit each model in turn; the warm start below seeds the next one
+            model.fit(
                 nb_max_iteration,
                 lr,
                 class_optimizer,
@@ -790,8 +846,17 @@ class PLNPCA:
                 do_smart_init,
                 verbose,
             )
+            if i < len(self.list_models) - 1:
+                next_model = self.list_models[i + 1]
+                self.init_next_model_with_previous_parameters(next_model, model)
         self._print_ending_message()
 
+    def init_next_model_with_previous_parameters(self, next_model, current_model):
+        next_model.coef = current_model.coef
+        next_model.components = torch.zeros(next_model.dim, next_model.rank)
+        with torch.no_grad():
+            next_model._components[:, : current_model.rank] = current_model._components
+
     def _print_ending_message(self):
         delimiter = "=" * NB_CHARACTERS_FOR_NICE_PLOT
         print(f"{delimiter}\n")
@@ -926,8 +991,15 @@ class _PLNPCA(_PLN):
     _NAME = "_PLNPCA"
     _components: torch.Tensor
 
-    @singledispatchmethod
-    def __init__(self, counts, covariates, offsets, rank, dict_initialization=None):
+    def __init__(
+        self,
+        counts,
+        covariates=None,
+        offsets=None,
+        offsets_formula="logsum",
+        rank=5,
+        dict_initialization=None,
+    ):
         self._rank = rank
         self._counts, self._covariates, self._offsets = _format_model_param(
             counts, covariates, offsets, None, take_log_offsets=False
@@ -935,14 +1007,18 @@ class _PLNPCA(_PLN):
         _check_data_shape(self._counts, self._covariates, self._offsets)
         self._check_if_rank_is_too_high()
         if dict_initialization is not None:
-            self._set__init_parameters(dict_initialization)
+            self._set_init_parameters(dict_initialization)
         self._fitted = False
         self._plotargs = _PlotArgs(self._WINDOW)
 
-    @__init__.register(str)
-    def _(self, formula, data, rank, dict_initialization):
+    @classmethod
+    def from_formula(
+        cls, formula, data, rank=5, offsets_formula="logsum", dict_initialization=None
+    ):
         counts, covariates, offsets = _extract_data_from_formula(formula, data)
-        self.__init__(counts, covariates, offsets, rank, dict_initialization)
+        return cls(
+            counts, covariates, offsets, offsets_formula, rank, dict_initialization
+        )
 
     def _check_if_rank_is_too_high(self):
         if self.dim < self.rank:
@@ -954,11 +1030,49 @@ class _PLNPCA(_PLN):
             warnings.warn(warning_string)
             self._rank = self.dim
 
+    @property
+    def latent_mean(self):
+        return self._cpu_attribute_or_none("_latent_mean")
+
+    @property
+    def latent_var(self):
+        return self._cpu_attribute_or_none("_latent_var")
+
+    @latent_mean.setter
+    @array2tensor
+    def latent_mean(self, latent_mean):
+        if latent_mean.shape != (self.n_samples, self.rank):
+            raise ValueError(
+                f"Wrong shape. Expected {self.n_samples, self.rank}, got {latent_mean.shape}"
+            )
+        self._latent_mean = latent_mean
+
+    @latent_var.setter
+    @array2tensor
+    def latent_var(self, latent_var):
+        if latent_var.shape != (self.n_samples, self.rank):
+            raise ValueError(
+                f"Wrong shape. Expected {self.n_samples, self.rank}, got {latent_var.shape}"
+            )
+        self._latent_var = latent_var
+
     @property
     def directory_name(self):
         return f"{self._NAME}_rank_{self._rank}"
         # return f"PLNPCA_nbcov_{self.nb_cov}_dim_{self.dim}/{self._NAME}_rank_{self._rank}"
 
+    @property
+    def covariates(self):
+        return self._cpu_attribute_or_none("_covariates")
+
+    @covariates.setter
+    @array2tensor
+    def covariates(self, covariates):
+        _check_data_shape(self.counts, covariates, self.offsets)
+        self._covariates = covariates
+        print("Setting coef to initialization")
+        self._smart_init_coef()
+
     @property
     def path_to_directory(self):
         return f"PLNPCA_nbcov_{self.nb_cov}_dim_{self.dim}/"
@@ -1076,10 +1190,15 @@ class _PLNPCA(_PLN):
 
     @property
     def components(self):
-        return self._attribute_or_none("_components")
+        return self._cpu_attribute_or_none("_components")
 
     @components.setter
+    @array2tensor
     def components(self, components):
+        if components.shape != (self.dim, self.rank):
+            raise ValueError(
+                f"Wrong shape. Expected {self.dim, self.rank}, got {components.shape}"
+            )
         self._components = components
 
     def viz(self, ax=None, colors=None):
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 0000000000000000000000000000000000000000..94b2dc2a7d819ecd920828f1fa82abe1552dffca
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,51 @@
+[build-system]
+requires = ["setuptools", "setuptools-scm"]
+build-backend = "setuptools.build_meta"
+
+[tool.setuptools_scm]
+
+
+[project]
+name = "pyPLNmodels"
+dynamic = ["version"]
+description = "Package implementing PLN models"
+readme = "README.md"
+license = {text = "MIT License"}
+requires-python = ">=3.7"
+keywords = [
+        "python",
+        "count",
+        "data",
+        "count data",
+        "high dimension",
+        "scRNAseq",
+        "PLN",
+        ]
+authors = [
+  {name = "Bastien Batardiere", email = "bastien.batardiere@gmail.com"},
+  {name = "Julien Chiquet", email = "julien.chiquet@inrae.fr"},
+  {name = "Joon Kwon", email = "joon.kwon@inrae.fr"},
+]
+maintainers = [
+  {name = "Bastien Batardière", email = "bastien.batardiere@gmail.com"},
+  {name = "Julien Chiquet", email = "julien.chiquet@inrae.fr"},
+]
+classifiers = [
+        # How mature is this project? Common values are
+        #   3 - Alpha
+        #   4 - Beta
+        #   5 - Production/Stable
+        "Development Status :: 4 - Alpha",
+        # Indicate who your project is intended for
+        "Intended Audience :: Science/Research",
+        # Pick your license as you wish (should match "license" above)
+        "License :: OSI Approved :: MIT License",
+        # Specify the Python versions you support here. In particular, ensure
+        # that you indicate whether you support Python 2, Python 3 or both.
+        "Programming Language :: Python :: 3 :: Only",
+]
+
+[tool.setuptools.dynamic]
+dependencies = {file = ["requirements.txt"]}
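+# "dependencies" is declared dynamic above, so setuptools reads the
+# dependency list from requirements.txt at build time.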
diff --git a/tests/conftest.py b/tests/conftest.py
index 1fd52a44e29889465b7a832f536dab2c86e59d81..f4cdbbbda93ffee62871db3ec65cbe120e5ea013 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -55,51 +55,29 @@ instances = []
 # dict_fixtures_models = []
 
 
-@singledispatch
-def convenient_plnpca(
-    counts,
-    covariates=None,
-    offsets=None,
-    offsets_formula=None,
-    dict_initialization=None,
-):
-    return _PLNPCA(
-        counts, covariates, offsets, rank=RANK, dict_initialization=dict_initialization
-    )
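+# The convenient* helpers below replace the singledispatch versions: they
+# dispatch on the type of the first argument, routing a formula string to
+# from_formula and anything else to the plain constructor.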
+def convenient_plnpca(*args, **kwargs):
+    dict_init = kwargs.pop("dict_initialization", None)
+    if isinstance(args[0], str):
+        return _PLNPCA.from_formula(
+            *args, **kwargs, dict_initialization=dict_init, rank=RANK
+        )
+    print("rank:", RANK)
+    return _PLNPCA(*args, **kwargs, dict_initialization=dict_init, rank=RANK)
 
 
-@convenient_plnpca.register(str)
-def _(formula, data, offsets_formula=None, dict_initialization=None):
-    return _PLNPCA(formula, data, rank=RANK, dict_initialization=dict_initialization)
+def convenientplnpca(*args, **kwargs):
+    dict_init = kwargs.pop("dict_initialization", None)
+    if isinstance(args[0], str):
+        return PLNPCA.from_formula(
+            *args, **kwargs, dict_of_dict_initialization=dict_init, ranks=RANKS
+        )
+    return PLNPCA(*args, **kwargs, dict_of_dict_initialization=dict_init, ranks=RANKS)
 
 
-@singledispatch
-def convenientplnpca(
-    counts,
-    covariates=None,
-    offsets=None,
-    offsets_formula=None,
-    dict_initialization=None,
-):
-    return PLNPCA(
-        counts,
-        covariates,
-        offsets,
-        offsets_formula,
-        dict_of_dict_initialization=dict_initialization,
-        ranks=RANKS,
-    )
-
-
-@convenientplnpca.register(str)
-def _(formula, data, offsets_formula=None, dict_initialization=None):
-    return PLNPCA(
-        formula,
-        data,
-        offsets_formula,
-        ranks=RANKS,
-        dict_of_dict_initialization=dict_initialization,
-    )
+def convenientpln(*args, **kwargs):
+    if isinstance(args[0], str):
+        return PLN.from_formula(*args, **kwargs)
+    return PLN(*args, **kwargs)
 
 
 def generate_new_model(model, *args, **kwargs):
@@ -109,7 +87,7 @@ def generate_new_model(model, *args, **kwargs):
         path = model.path_to_directory + name_dir
         init = load_model(path)
         if name == "PLN":
-            new = PLN(*args, **kwargs, dict_initialization=init)
+            new = convenientpln(*args, **kwargs, dict_initialization=init)
         if name == "_PLNPCA":
             new = convenient_plnpca(*args, **kwargs, dict_initialization=init)
     if name == "PLNPCA":
@@ -129,7 +107,7 @@ def cache(func):
     return new_func
 
 
-params = [PLN, convenient_plnpca, convenientplnpca]
+params = [convenientpln, convenient_plnpca, convenientplnpca]
 dict_fixtures = {}
 
 
diff --git a/tests/test_common.py b/tests/test_common.py
index 42ae334919d9740e94b8481bd19c21269b2cba95..4f539f0d9e0fc4bad0066ada154e120cc1511e12 100644
--- a/tests/test_common.py
+++ b/tests/test_common.py
@@ -28,7 +28,7 @@ def test_print(any_pln):
 def test_show_coef_transform_covariance_pcaprojected(any_pln):
     any_pln.show()
     any_pln._plotargs._show_loss()
-    any_pln._plotargs._show_stopping_criteration()
+    any_pln._plotargs._show_stopping_criterion()
     assert hasattr(any_pln, "coef")
     assert callable(any_pln.transform)
     assert hasattr(any_pln, "covariance")
diff --git a/tests/test_setters.py b/tests/test_setters.py
index b1a9ba29c09d21faa2a10bc5dae71dbe368cbd7e..727c6ac697d253cf42d8957c4ff2541e926db718 100644
--- a/tests/test_setters.py
+++ b/tests/test_setters.py
@@ -1,21 +1,148 @@
 import pytest
 import pandas as pd
+import torch
 
 from tests.conftest import dict_fixtures
 from tests.utils import MSE, filter_models
 
 
 @pytest.mark.parametrize("pln", dict_fixtures["all_pln"])
-@filter_models(["PLN", "PLNPCA"])
-def test_setter_with_numpy(pln):
+def test_data_setter_with_torch(pln):
+    pln.counts = pln.counts
+    pln.covariates = pln.covariates
+    pln.offsets = pln.offsets
+    pln.fit()
+
+
+@pytest.mark.parametrize("pln", dict_fixtures["loaded_and_fitted_pln"])
+@filter_models(["PLN", "_PLNPCA"])
+def test_parameters_setter_with_torch(pln):
+    pln.latent_mean = pln.latent_mean
+    pln.latent_var = pln.latent_var
+    pln.coef = pln.coef
+    if pln._NAME == "_PLNPCA":
+        pln.components = pln.components
+    pln.fit()
+
+
+@pytest.mark.parametrize("pln", dict_fixtures["all_pln"])
+def test_data_setter_with_numpy(pln):
     np_counts = pln.counts.numpy()
+    if pln.covariates is not None:
+        np_covariates = pln.covariates.numpy()
+    else:
+        np_covariates = None
+    np_offsets = pln.offsets.numpy()
     pln.counts = np_counts
+    pln.covariates = np_covariates
+    pln.offsets = np_offsets
+    pln.fit()
+
+
+@pytest.mark.parametrize("pln", dict_fixtures["loaded_and_fitted_pln"])
+@filter_models(["PLN", "_PLNPCA"])
+def test_parameters_setter_with_numpy(pln):
+    np_latent_mean = pln.latent_mean.numpy()
+    np_latent_var = pln.latent_var.numpy()
+    if pln.coef is not None:
+        np_coef = pln.coef.numpy()
+    else:
+        np_coef = None
+    pln.latent_mean = np_latent_mean
+    pln.latent_var = np_latent_var
+    pln.coef = np_coef
+    if pln._NAME == "_PLNPCA":
+        pln.components = pln.components.numpy()
     pln.fit()
 
 
 @pytest.mark.parametrize("pln", dict_fixtures["all_pln"])
-@filter_models(["PLN", "PLNPCA"])
-def test_setter_with_pandas(pln):
+def test_data_setter_with_pandas(pln):
     pd_counts = pd.DataFrame(pln.counts.numpy())
+    if pln.covariates is not None:
+        pd_covariates = pd.DataFrame(pln.covariates.numpy())
+    else:
+        pd_covariates = None
+    pd_offsets = pd.DataFrame(pln.offsets.numpy())
     pln.counts = pd_counts
+    pln.covariates = pd_covariates
+    pln.offsets = pd_offsets
     pln.fit()
+
+
+@pytest.mark.parametrize("pln", dict_fixtures["loaded_and_fitted_pln"])
+@filter_models(["PLN", "_PLNPCA"])
+def test_parameters_setter_with_pandas(pln):
+    pd_latent_mean = pd.DataFrame(pln.latent_mean.numpy())
+    pd_latent_var = pd.DataFrame(pln.latent_var.numpy())
+    if pln.coef is not None:
+        pd_coef = pd.DataFrame(pln.coef.numpy())
+    else:
+        pd_coef = None
+    pln.latent_mean = pd_latent_mean
+    pln.latent_var = pd_latent_var
+    pln.coef = pd_coef
+    if pln._NAME == "_PLNPCA":
+        pln.components = pd.DataFrame(pln.components.numpy())
+    pln.fit()
+
+
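+# The setters now validate shapes (and non-negativity for counts); the
+# tests below check that inconsistent inputs raise ValueError.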
+@pytest.mark.parametrize("pln", dict_fixtures["all_pln"])
+def test_fail_data_setter_with_torch(pln):
+    with pytest.raises(ValueError):
+        pln.counts = pln.counts - 100
+
+    n, p = pln.counts.shape
+    if pln.covariates is None:
+        d = 0
+    else:
+        d = pln.covariates.shape[-1]
+    with pytest.raises(ValueError):
+        pln.counts = torch.zeros(n + 1, p)
+    with pytest.raises(ValueError):
+        pln.counts = torch.zeros(n, p + 1)
+
+    with pytest.raises(ValueError):
+        pln.covariates = torch.zeros(n + 1, d)
+
+    with pytest.raises(ValueError):
+        pln.offsets = torch.zeros(n + 1, p)
+
+    with pytest.raises(ValueError):
+        pln.offsets = torch.zeros(n, p + 1)
+
+
+@pytest.mark.parametrize("pln", dict_fixtures["loaded_and_fitted_pln"])
+@filter_models(["PLN", "_PLNPCA"])
+def test_fail_parameters_setter_with_torch(pln):
+    n, dim_latent = pln.latent_mean.shape
+    dim = pln.counts.shape[1]
+
+    with pytest.raises(ValueError):
+        pln.latent_mean = torch.zeros(n + 1, dim_latent)
+
+    with pytest.raises(ValueError):
+        pln.latent_mean = torch.zeros(n, dim_latent + 1)
+
+    with pytest.raises(ValueError):
+        pln.latent_var = torch.zeros(n + 1, dim_latent)
+
+    with pytest.raises(ValueError):
+        pln.latent_var = torch.zeros(n, dim_latent + 1)
+
+    if pln._NAME == "_PLNPCA":
+        with pytest.raises(ValueError):
+            pln.components = torch.zeros(dim, dim_latent + 1)
+
+        with pytest.raises(ValueError):
+            pln.components = torch.zeros(dim + 1, dim_latent)
+
+        if pln.covariates is None:
+            d = 0
+        else:
+            d = pln.covariates.shape[-1]
+        with pytest.raises(ValueError):
+            pln.coef = torch.zeros(d + 1, dim)
+
+        with pytest.raises(ValueError):
+            pln.coef = torch.zeros(d, dim + 1)