cuda-samples/Samples/4_CUDA_Libraries/cuSolverSp_LinearSolver/mmio_wrapper.cpp
2022-01-13 11:35:24 +05:30

530 lines
13 KiB
C++

#include <stdio.h>
#include <stdlib.h>
#include <math.h>
#include "mmio.h"
#include <cusolverDn.h>
/* avoid Windows warnings (for example: strcpy, fscanf, etc.) */
#if defined(_WIN32)
#define _CRT_SECURE_NO_WARNINGS
#endif
/* various __inline__ __device__ function to initialize a T_ELEM */
template <typename T_ELEM> __inline__ T_ELEM cuGet (int );
template <> __inline__ float cuGet<float >(int x)
{
return float(x);
}
template <> __inline__ double cuGet<double>(int x)
{
return double(x);
}
template <> __inline__ cuComplex cuGet<cuComplex>(int x)
{
return (make_cuComplex( float(x), 0.0f ));
}
template <> __inline__ cuDoubleComplex cuGet<cuDoubleComplex>(int x)
{
return (make_cuDoubleComplex( double(x), 0.0 ));
}
template <typename T_ELEM> __inline__ T_ELEM cuGet (int , int );
template <> __inline__ float cuGet<float >(int x, int y)
{
return float(x);
}
template <> __inline__ double cuGet<double>(int x, int y)
{
return double(x);
}
template <> __inline__ cuComplex cuGet<cuComplex>(int x, int y)
{
return make_cuComplex( float(x), float(y) );
}
template <> __inline__ cuDoubleComplex cuGet<cuDoubleComplex>(int x, int y)
{
return (make_cuDoubleComplex( double(x), double(y) ));
}
template <typename T_ELEM> __inline__ T_ELEM cuGet (float );
template <> __inline__ float cuGet<float >(float x)
{
return float(x);
}
template <> __inline__ double cuGet<double>(float x)
{
return double(x);
}
template <> __inline__ cuComplex cuGet<cuComplex>(float x)
{
return (make_cuComplex( float(x), 0.0f ));
}
template <> __inline__ cuDoubleComplex cuGet<cuDoubleComplex>(float x)
{
return (make_cuDoubleComplex( double(x), 0.0 ));
}
template <typename T_ELEM> __inline__ T_ELEM cuGet (float, float );
template <> __inline__ float cuGet<float >(float x, float y)
{
return float(x);
}
template <> __inline__ double cuGet<double>(float x, float y)
{
return double(x);
}
template <> __inline__ cuComplex cuGet<cuComplex>(float x, float y)
{
return (make_cuComplex( float(x), float(y) ));
}
template <> __inline__ cuDoubleComplex cuGet<cuDoubleComplex>(float x, float y)
{
return (make_cuDoubleComplex( double(x), double(y) ));
}
template <typename T_ELEM> __inline__ T_ELEM cuGet (double );
template <> __inline__ float cuGet<float >(double x)
{
return float(x);
}
template <> __inline__ double cuGet<double>(double x)
{
return double(x);
}
template <> __inline__ cuComplex cuGet<cuComplex>(double x)
{
return (make_cuComplex( float(x), 0.0f ));
}
template <> __inline__ cuDoubleComplex cuGet<cuDoubleComplex>(double x)
{
return (make_cuDoubleComplex( double(x), 0.0 ));
}
template <typename T_ELEM> __inline__ T_ELEM cuGet (double, double );
template <> __inline__ float cuGet<float >(double x, double y)
{
return float(x);
}
template <> __inline__ double cuGet<double>(double x, double y)
{
return double(x);
}
template <> __inline__ cuComplex cuGet<cuComplex>(double x, double y)
{
return (make_cuComplex( float(x), float(y) ));
}
template <> __inline__ cuDoubleComplex cuGet<cuDoubleComplex>(double x, double y)
{
return (make_cuDoubleComplex( double(x), double(y) ));
}
static void compress_index(
const int *Ind,
int nnz,
int m,
int *Ptr,
int base)
{
int i;
/* initialize everything to zero */
for(i=0; i<m+1; i++){
Ptr[i]=0;
}
/* count elements in every row */
Ptr[0]=base;
for(i=0; i<nnz; i++){
Ptr[Ind[i]+(1-base)]++;
}
/* add all the values */
for(i=0; i<m; i++){
Ptr[i+1]+=Ptr[i];
}
}
struct cooFormat {
int i ;
int j ;
int p ; // permutation
};
int cmp_cooFormat_csr( struct cooFormat *s, struct cooFormat *t)
{
if ( s->i < t->i ){
return -1 ;
}
else if ( s->i > t->i ){
return 1 ;
}
else{
return s->j - t->j ;
}
}
int cmp_cooFormat_csc( struct cooFormat *s, struct cooFormat *t)
{
if ( s->j < t->j ){
return -1 ;
}
else if ( s->j > t->j ){
return 1 ;
}
else{
return s->i - t->i ;
}
}
typedef int (*FUNPTR) (const void*, const void*) ;
typedef int (*FUNPTR2) ( struct cooFormat *s, struct cooFormat *t) ;
static FUNPTR2 fptr_array[2] = {
cmp_cooFormat_csr,
cmp_cooFormat_csc,
};
static int verify_pattern(
int m,
int nnz,
int *csrRowPtr,
int *csrColInd)
{
int i, col, start, end, base_index;
int error_found = 0;
if (nnz != (csrRowPtr[m] - csrRowPtr[0])){
fprintf(stderr, "Error (nnz check failed): (csrRowPtr[%d]=%d - csrRowPtr[%d]=%d) != (nnz=%d)\n", 0, csrRowPtr[0], m, csrRowPtr[m], nnz);
error_found = 1;
}
base_index = csrRowPtr[0];
if ((0 != base_index) && (1 != base_index)){
fprintf(stderr, "Error (base index check failed): base index = %d\n", base_index);
error_found = 1;
}
for (i=0; (!error_found) && (i<m); i++){
start = csrRowPtr[i ] - base_index;
end = csrRowPtr[i+1] - base_index;
if (start > end){
fprintf(stderr, "Error (corrupted row): csrRowPtr[%d] (=%d) > csrRowPtr[%d] (=%d)\n", i, start+base_index, i+1, end+base_index);
error_found = 1;
}
for (col=start; col<end; col++){
if (csrColInd[col] < base_index){
fprintf(stderr, "Error (column vs. base index check failed): csrColInd[%d] < %d\n", col, base_index);
error_found = 1;
}
if ((col < (end-1)) && (csrColInd[col] >= csrColInd[col+1])){
fprintf(stderr, "Error (sorting of the column indecis check failed): (csrColInd[%d]=%d) >= (csrColInd[%d]=%d)\n", col, csrColInd[col], col+1, csrColInd[col+1]);
error_found = 1;
}
}
}
return error_found ;
}
template <typename T_ELEM>
int loadMMSparseMatrix(
char *filename,
char elem_type,
bool csrFormat,
int *m,
int *n,
int *nnz,
T_ELEM **aVal,
int **aRowInd,
int **aColInd,
int extendSymMatrix)
{
MM_typecode matcode;
double *tempVal;
int *tempRowInd,*tempColInd;
double *tval;
int *trow,*tcol;
int *csrRowPtr, *cscColPtr;
int i,j,error,base,count;
struct cooFormat *work;
/* read the matrix */
error = mm_read_mtx_crd(filename, m, n, nnz, &trow, &tcol, &tval, &matcode);
if (error) {
fprintf(stderr, "!!!! can not open file: '%s'\n", filename);
return 1;
}
/* start error checking */
if (mm_is_complex(matcode) && ((elem_type != 'z') && (elem_type != 'c'))) {
fprintf(stderr, "!!!! complex matrix requires type 'z' or 'c'\n");
return 1;
}
if (mm_is_dense(matcode) || mm_is_array(matcode) || mm_is_pattern(matcode) /*|| mm_is_integer(matcode)*/){
fprintf(stderr, "!!!! dense, array, pattern and integer matrices are not supported\n");
return 1;
}
/* if necessary symmetrize the pattern (transform from triangular to full) */
if ((extendSymMatrix) && (mm_is_symmetric(matcode) || mm_is_hermitian(matcode) || mm_is_skew(matcode))){
//count number of non-diagonal elements
count=0;
for(i=0; i<(*nnz); i++){
if (trow[i] != tcol[i]){
count++;
}
}
//allocate space for the symmetrized matrix
tempRowInd = (int *)malloc((*nnz + count) * sizeof(int));
tempColInd = (int *)malloc((*nnz + count) * sizeof(int));
if (mm_is_real(matcode) || mm_is_integer(matcode)){
tempVal = (double *)malloc((*nnz + count) * sizeof(double));
}
else{
tempVal = (double *)malloc(2 * (*nnz + count) * sizeof(double));
}
//copy the elements regular and transposed locations
for(j=0, i=0; i<(*nnz); i++){
tempRowInd[j]=trow[i];
tempColInd[j]=tcol[i];
if (mm_is_real(matcode) || mm_is_integer(matcode)){
tempVal[j]=tval[i];
}
else{
tempVal[2*j] =tval[2*i];
tempVal[2*j+1]=tval[2*i+1];
}
j++;
if (trow[i] != tcol[i]){
tempRowInd[j]=tcol[i];
tempColInd[j]=trow[i];
if (mm_is_real(matcode) || mm_is_integer(matcode)){
if (mm_is_skew(matcode)){
tempVal[j]=-tval[i];
}
else{
tempVal[j]= tval[i];
}
}
else{
if(mm_is_hermitian(matcode)){
tempVal[2*j] = tval[2*i];
tempVal[2*j+1]=-tval[2*i+1];
}
else{
tempVal[2*j] = tval[2*i];
tempVal[2*j+1]= tval[2*i+1];
}
}
j++;
}
}
(*nnz)+=count;
//free temporary storage
free(trow);
free(tcol);
free(tval);
}
else{
tempRowInd=trow;
tempColInd=tcol;
tempVal =tval;
}
// life time of (trow, tcol, tval) is over.
// please use COO format (tempRowInd, tempColInd, tempVal)
// use qsort to sort COO format
work = (struct cooFormat *)malloc(sizeof(struct cooFormat)*(*nnz));
if (NULL == work){
fprintf(stderr, "!!!! allocation error, malloc failed\n");
return 1;
}
for(i=0; i<(*nnz); i++){
work[i].i = tempRowInd[i];
work[i].j = tempColInd[i];
work[i].p = i; // permutation is identity
}
if (csrFormat){
/* create row-major ordering of indices (sorted by row and within each row by column) */
qsort(work, *nnz, sizeof(struct cooFormat), (FUNPTR)fptr_array[0] );
}else{
/* create column-major ordering of indices (sorted by column and within each column by row) */
qsort(work, *nnz, sizeof(struct cooFormat), (FUNPTR)fptr_array[1] );
}
// (tempRowInd, tempColInd) is sorted either by row-major or by col-major
for(i=0; i<(*nnz); i++){
tempRowInd[i] = work[i].i;
tempColInd[i] = work[i].j;
}
// setup base
// check if there is any row/col 0, if so base-0
// check if there is any row/col equal to matrix dimension m/n, if so base-1
int base0 = 0;
int base1 = 0;
for(i=0; i<(*nnz); i++){
const int row = tempRowInd[i];
const int col = tempColInd[i];
if ( (0 == row) || (0 == col) ){
base0 = 1;
}
if ( (*m == row) || (*n == col) ){
base1 = 1;
}
}
if ( base0 && base1 ){
printf("Error: input matrix is base-0 and base-1 \n");
return 1;
}
base = 0;
if (base1){
base = 1;
}
/* compress the appropriate indices */
if (csrFormat){
/* CSR format (assuming row-major format) */
csrRowPtr = (int *)malloc(((*m)+1) * sizeof(csrRowPtr[0]));
if (!csrRowPtr) return 1;
compress_index(tempRowInd, *nnz, *m, csrRowPtr, base);
*aRowInd = csrRowPtr;
*aColInd = (int *)malloc((*nnz) * sizeof(int));
}
else {
/* CSC format (assuming column-major format) */
cscColPtr = (int *)malloc(((*n)+1) * sizeof(cscColPtr[0]));
if (!cscColPtr) return 1;
compress_index(tempColInd, *nnz, *n, cscColPtr, base);
*aColInd = cscColPtr;
*aRowInd = (int *)malloc((*nnz) * sizeof(int));
}
/* transfrom the matrix values of type double into one of the cusparse library types */
*aVal = (T_ELEM *)malloc((*nnz) * sizeof(T_ELEM));
for (i=0; i<(*nnz); i++) {
if (csrFormat){
(*aColInd)[i] = tempColInd[i];
}
else{
(*aRowInd)[i] = tempRowInd[i];
}
if (mm_is_real(matcode) || mm_is_integer(matcode)){
(*aVal)[i] = cuGet<T_ELEM>( tempVal[ work[i].p ] );
}
else{
(*aVal)[i] = cuGet<T_ELEM>(tempVal[2*work[i].p], tempVal[2*work[i].p+1]);
}
}
/* check for corruption */
int error_found;
if (csrFormat){
error_found = verify_pattern(*m, *nnz, *aRowInd, *aColInd);
}else{
error_found = verify_pattern(*n, *nnz, *aColInd, *aRowInd);
}
if (error_found){
fprintf(stderr, "!!!! verify_pattern failed\n");
return 1;
}
/* cleanup and exit */
free(work);
free(tempVal);
free(tempColInd);
free(tempRowInd);
return 0;
}
/* specific instantiation */
template int loadMMSparseMatrix<float>(
char *filename,
char elem_type,
bool csrFormat,
int *m,
int *n,
int *nnz,
float **aVal,
int **aRowInd,
int **aColInd,
int extendSymMatrix);
template int loadMMSparseMatrix<double>(
char *filename,
char elem_type,
bool csrFormat,
int *m,
int *n,
int *nnz,
double **aVal,
int **aRowInd,
int **aColInd,
int extendSymMatrix);
template int loadMMSparseMatrix<cuComplex>(
char *filename,
char elem_type,
bool csrFormat,
int *m,
int *n,
int *nnz,
cuComplex **aVal,
int **aRowInd,
int **aColInd,
int extendSymMatrix);
template int loadMMSparseMatrix<cuDoubleComplex>(
char *filename,
char elem_type,
bool csrFormat,
int *m,
int *n,
int *nnz,
cuDoubleComplex **aVal,
int **aRowInd,
int **aColInd,
int extendSymMatrix);