# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""Format module for sparse tensor algebra."""
from typing import Callable, Dict, List, Union
import tvm._ffi
import tvm.tir
from tvm.runtime import Object
from tvm.tir import IndexMap, _ffi_api
from tvm import IRModule
from tvm.tir.transform import SparseFormatDecompose
def column_part_hyb(num_rows, num_cols, indptr_nd, indices_nd, num_col_parts, buckets):
    """Partition a CSR matrix by columns and bucket its rows by non-zero count.

    The input CSR matrix is partitioned by columns, and rows are collected
    into buckets according to the number of non-zero elements per row.

    Parameters
    ----------
    num_rows : int
        Number of rows in the CSR matrix.
    num_cols : int
        Number of columns in the CSR matrix.
    indptr_nd : NDArray
        The indptr array of the CSR matrix.
    indices_nd : NDArray
        The indices array of the CSR matrix.
    num_col_parts : int
        Number of column partitions.
    buckets : List
        The bucket sizes array.

    Returns
    -------
    Tuple[List[List[NDArray]]]
        The triple ``(row_indices, col_indices, mask)``.
        ``row_indices`` is stored as a list of lists with shape
        ``(num_col_parts, len(buckets))``, where the innermost element is an
        NDArray. ``col_indices`` and ``mask`` are stored in the same way.
    """
    # Delegates entirely to the registered ``ColumnPartHyb`` FFI function;
    # ``_ffi_api`` attributes are registered dynamically, hence the type ignore.
    return _ffi_api.ColumnPartHyb(
        num_rows, num_cols, indptr_nd, indices_nd, num_col_parts, buckets  # type: ignore
    )
def condense(indptr_nd, indices_nd, t, g):
    """Condense a CSR sparse matrix into (t x 1) tiles and group g tiles together.

    Parameters
    ----------
    indptr_nd : NDArray
        The indptr array of the CSR format.
    indices_nd : NDArray
        The indices array of the CSR format.
    t : int
        The tile size.
    g : int
        The group size.

    Returns
    -------
    Tuple[NDArray]
        The triple ``(group_indptr, tile_indices, mask)``.
    """
    # Delegates entirely to the registered ``ConDense`` FFI function;
    # ``_ffi_api`` attributes are registered dynamically, hence the type ignore.
    return _ffi_api.ConDense(indptr_nd, indices_nd, t, g)  # type: ignore
def csf_to_ell3d(
    csf_indptr_0, csf_indices_0, csf_indptr_1, csf_indices_1, nnz_rows_bkt, nnz_cols_bkt
):
    """Convert CSF format to composable ELL format in a 3-dimensional setting.

    This targets the heterogeneous-graph (HeteroGraph) setting.

    Parameters
    ----------
    csf_indptr_0 : NDArray
        Level 0 indptr array in CSF format.
    csf_indices_0 : NDArray
        Level 0 indices array in CSF format.
    csf_indptr_1 : NDArray
        Level 1 indptr array in CSF format.
    csf_indices_1 : NDArray
        Level 1 indices array in CSF format.
    nnz_rows_bkt : List[int]
        Number of non-zero rows bucket.
    nnz_cols_bkt : List[int]
        Number of non-zero columns bucket.

    Returns
    -------
    Tuple[List[NDArray]]
        ``(indptr, row_indices, col_indices, mask)``.
        Each one is a list of NDArray, with length #rels (the number of
        relations).
    """
    # Delegates entirely to the registered ``CSFToELL3D`` FFI function;
    # ``_ffi_api`` attributes are registered dynamically, hence the type ignore.
    return _ffi_api.CSFToELL3D(
        csf_indptr_0,
        csf_indices_0,
        csf_indptr_1,
        csf_indices_1,
        nnz_rows_bkt,
        nnz_cols_bkt,  # type: ignore
    )