# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Format module for sparse tensor algebra."""
from typing import Callable, Dict, List, Union
import tvm._ffi
import tvm.tir
from tvm.runtime import Object
from tvm.tir import IndexMap, _ffi_api
from tvm import IRModule
from tvm.tir.transform import SparseFormatDecompose
def column_part_hyb(num_rows, num_cols, indptr_nd, indices_nd, num_col_parts, buckets):
    """Partition a CSR matrix by columns and bucket rows by non-zeros per row.

    The input CSR matrix is split into ``num_col_parts`` column partitions and,
    inside each partition, rows are collected into buckets according to their
    number of non-zero elements. The work is delegated to the
    ``ColumnPartHyb`` function exposed through ``_ffi_api``.

    Parameters
    ----------
    num_rows : int
        Number of rows in the CSR matrix.
    num_cols : int
        Number of columns in the CSR matrix.
    indptr_nd : NDArray
        The indptr array of CSR matrix.
    indices_nd : NDArray
        The indices array of CSR matrix.
    num_col_parts : int
        Number of column partitions.
    buckets : List
        The bucket sizes array.

    Returns
    -------
    Tuple[List[List[NDArray]]]
        The triple of (row_indices, col_indices, mask).
        row_indices is stored as a list of lists with shape
        (num_col_parts, len(buckets)), where the innermost element is an
        NDArray. col_indices and mask are stored in the same way.
    """
    return _ffi_api.ColumnPartHyb(
        num_rows, num_cols, indptr_nd, indices_nd, num_col_parts, buckets  # type: ignore
    )
def condense(indptr_nd, indices_nd, t, g):
    """Condense a CSR sparse matrix into (t x 1) tiles grouped g at a time.

    The work is delegated to the ``ConDense`` function exposed through
    ``_ffi_api``.

    Parameters
    ----------
    indptr_nd : NDArray
        The indptr array of CSR format.
    indices_nd : NDArray
        The indices array of CSR format.
    t : int
        The tile size.
    g : int
        The group size.

    Returns
    -------
    Tuple[NDArray]
        The triple of (group_indptr, tile_indices, mask).
    """
    return _ffi_api.ConDense(indptr_nd, indices_nd, t, g)  # type: ignore
def csf_to_ell3d(
    csf_indptr_0, csf_indices_0, csf_indptr_1, csf_indices_1, nnz_rows_bkt, nnz_cols_bkt
):
    """Convert CSF format to composable ELL format in 3-dimensional setting (HeteroGraphs).

    The work is delegated to the ``CSFToELL3D`` function exposed through
    ``_ffi_api``.

    Parameters
    ----------
    csf_indptr_0 : NDArray
        Level 0 indptr array in CSF format.
    csf_indices_0 : NDArray
        Level 0 indices array in CSF format.
    csf_indptr_1 : NDArray
        Level 1 indptr array in CSF format.
    csf_indices_1 : NDArray
        Level 1 indices array in CSF format.
    nnz_rows_bkt : List[int]
        Number of non-zero rows bucket.
    nnz_cols_bkt : List[int]
        Number of non-zero columns bucket.

    Returns
    -------
    Tuple[List[NDArray]]
        (indptr, row_indices, col_indices, mask)
        Each one is a list of NDArray, with length #rels.
    """
    return _ffi_api.CSFToELL3D(  # type: ignore
        csf_indptr_0, csf_indices_0, csf_indptr_1, csf_indices_1, nnz_rows_bkt, nnz_cols_bkt
    )