from pymatting.util.util import trimap_split
import scipy.sparse
def make_linear_system(L, trimap, lambda_value=100.0, return_c=False):
    r"""Construct a constrained linear system from a matting Laplacian.

    The foreground and background pixels are constrained with a diagonal
    matrix `C` to values in the right-hand-side vector `b`. The constraints
    are weighted by a factor :math:`\lambda`. The linear system is given as

    .. math::

        A = L + \lambda C,

    where :math:`C=\mathop{Diag}(c)` having :math:`c_i = 1` if pixel i is
    known and :math:`c_i = 0` otherwise.

    The right-hand side :math:`b` is a vector with entries
    :math:`b_i = \lambda` if pixel i is a foreground pixel and
    :math:`b_i = 0` otherwise.

    Which pixels count as foreground or known is decided by
    :code:`trimap_split`.

    Parameters
    ----------
    L: scipy.sparse.spmatrix
        Laplacian matrix, e.g. calculated with :code:`lbdm_laplacian` function
    trimap: numpy.ndarray
        Trimap with shape :math:`h\times w`
    lambda_value: float
        Constraint penalty, defaults to 100
    return_c: bool
        Whether to return the constraint weights `c`, defaults to False

    Returns
    -------
    A: scipy.sparse.spmatrix
        Matrix describing the system of linear equations, in CSR format
    b: numpy.ndarray
        Vector describing the right-hand side of the system
    c: numpy.ndarray
        Vector describing the diagonal entries of the weighted constraint
        matrix :math:`\lambda C` (i.e. already scaled by `lambda_value`),
        only returned if `return_c` is set to True
    """
    # Only the foreground and known masks are needed; the background and
    # unknown masks returned by trimap_split are ignored.
    is_fg, _, is_known, _ = trimap_split(trimap)

    # Both the constraint diagonal and the right-hand side carry the penalty
    # weight, so the system reads (L + lambda*C) x = lambda * is_fg.
    c = lambda_value * is_known
    b = lambda_value * is_fg

    A = (L + scipy.sparse.diags(c)).tocsr()
    # Merge duplicate entries so the CSR matrix is in canonical form.
    A.sum_duplicates()

    if return_c:
        return A, b, c

    return A, b