474 lines
16 KiB
C++

/* avoid Windows warnings (for example: strcpy, fscanf, etc.);
   the macro only takes effect if it is defined before any CRT header is included */
#if defined(_WIN32)
#define _CRT_SECURE_NO_WARNINGS
#endif
#include <cusolverDn.h>
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include "mmio.h"
/*
 * cuGet<T_ELEM>(...) turns plain host scalars into a value of element type
 * T_ELEM (float, double, cuComplex or cuDoubleComplex).
 *  - One argument:  the value becomes the real part; complex results get an
 *                   imaginary part of zero.
 *  - Two arguments: interpreted as (real, imaginary). For the real element
 *                   types the second argument is ignored, so the imaginary
 *                   part is dropped.
 * These are ordinary inline host functions (no __device__ qualifier).
 */

/* ---- from int ---- */
template <typename T_ELEM> __inline__ T_ELEM cuGet(int);
template <> __inline__ float cuGet<float>(int re) { return static_cast<float>(re); }
template <> __inline__ double cuGet<double>(int re) { return static_cast<double>(re); }
template <> __inline__ cuComplex cuGet<cuComplex>(int re) { return make_cuComplex(static_cast<float>(re), 0.0f); }
template <> __inline__ cuDoubleComplex cuGet<cuDoubleComplex>(int re)
{
    return make_cuDoubleComplex(static_cast<double>(re), 0.0);
}

/* ---- from (int, int) ---- */
template <typename T_ELEM> __inline__ T_ELEM cuGet(int, int);
template <> __inline__ float cuGet<float>(int re, int /* im: dropped */) { return static_cast<float>(re); }
template <> __inline__ double cuGet<double>(int re, int /* im: dropped */) { return static_cast<double>(re); }
template <> __inline__ cuComplex cuGet<cuComplex>(int re, int im)
{
    return make_cuComplex(static_cast<float>(re), static_cast<float>(im));
}
template <> __inline__ cuDoubleComplex cuGet<cuDoubleComplex>(int re, int im)
{
    return make_cuDoubleComplex(static_cast<double>(re), static_cast<double>(im));
}

/* ---- from float ---- */
template <typename T_ELEM> __inline__ T_ELEM cuGet(float);
template <> __inline__ float cuGet<float>(float re) { return re; }
template <> __inline__ double cuGet<double>(float re) { return static_cast<double>(re); }
template <> __inline__ cuComplex cuGet<cuComplex>(float re) { return make_cuComplex(re, 0.0f); }
template <> __inline__ cuDoubleComplex cuGet<cuDoubleComplex>(float re)
{
    return make_cuDoubleComplex(static_cast<double>(re), 0.0);
}

/* ---- from (float, float) ---- */
template <typename T_ELEM> __inline__ T_ELEM cuGet(float, float);
template <> __inline__ float cuGet<float>(float re, float /* im: dropped */) { return re; }
template <> __inline__ double cuGet<double>(float re, float /* im: dropped */) { return static_cast<double>(re); }
template <> __inline__ cuComplex cuGet<cuComplex>(float re, float im) { return make_cuComplex(re, im); }
template <> __inline__ cuDoubleComplex cuGet<cuDoubleComplex>(float re, float im)
{
    return make_cuDoubleComplex(static_cast<double>(re), static_cast<double>(im));
}

/* ---- from double ---- */
template <typename T_ELEM> __inline__ T_ELEM cuGet(double);
template <> __inline__ float cuGet<float>(double re) { return static_cast<float>(re); }
template <> __inline__ double cuGet<double>(double re) { return re; }
template <> __inline__ cuComplex cuGet<cuComplex>(double re) { return make_cuComplex(static_cast<float>(re), 0.0f); }
template <> __inline__ cuDoubleComplex cuGet<cuDoubleComplex>(double re)
{
    return make_cuDoubleComplex(re, 0.0);
}

/* ---- from (double, double) ---- */
template <typename T_ELEM> __inline__ T_ELEM cuGet(double, double);
template <> __inline__ float cuGet<float>(double re, double /* im: dropped */) { return static_cast<float>(re); }
template <> __inline__ double cuGet<double>(double re, double /* im: dropped */) { return re; }
template <> __inline__ cuComplex cuGet<cuComplex>(double re, double im)
{
    return make_cuComplex(static_cast<float>(re), static_cast<float>(im));
}
template <> __inline__ cuDoubleComplex cuGet<cuDoubleComplex>(double re, double im)
{
    return make_cuDoubleComplex(re, im);
}
/*
 * Builds the m + 1 offset array of a compressed (CSR/CSC) layout from a list
 * of nnz row (or column) indices.
 *   Ind   nnz indices, each expected to lie in [base, base + m)
 *   nnz   number of indices in Ind
 *   m     size of the compressed dimension
 *   Ptr   output, m + 1 entries; Ptr[k] is the offset (in the same base) of the
 *         first entry of row/column k, and Ptr[m] - Ptr[0] is the number of
 *         indices that were counted
 *   base  index base (0 or 1) shared by Ind and Ptr
 * Indices outside [base, base + m) are not counted: writing them would run past
 * the end of Ptr. Such input therefore shows up as Ptr[m] - Ptr[0] != nnz.
 */
static void compress_index(const int *Ind, int nnz, int m, int *Ptr, int base)
{
    /* Ptr[0] carries the base; Ptr[k + 1] first accumulates the count of row k */
    Ptr[0] = base;
    for (int k = 1; k <= m; k++) {
        Ptr[k] = 0;
    }
    for (int e = 0; e < nnz; e++) {
        /* row k (in the given base) is counted in slot k - base + 1; computed in
           64 bits so extreme (corrupt) index values cannot overflow */
        const long long slot = (long long)Ind[e] - base + 1;
        if ((slot >= 1) && (slot <= m)) {
            Ptr[slot]++;
        }
    }
    /* prefix sum turns per-row counts into offsets */
    for (int k = 0; k < m; k++) {
        Ptr[k + 1] += Ptr[k];
    }
}
/* One coordinate (COO) entry as it is moved around while sorting: its
   (row, column) position plus the index it had before sorting, so the
   matching value can be looked up afterwards. */
struct cooFormat
{
    int i, j; // row index, column index
    int p;    // position of this entry before sorting (permutation)
};
/* Row-major ordering of COO entries: row index first, column index as the
   tie-breaker. Returns a negative, zero or positive value like a qsort()
   comparator. */
int cmp_cooFormat_csr(struct cooFormat *s, struct cooFormat *t)
{
    const int rowA = s->i;
    const int rowB = t->i;
    if (rowA != rowB) {
        return (rowA < rowB) ? -1 : 1;
    }
    return s->j - t->j;
}
/* Column-major ordering of COO entries: column index first, row index as the
   tie-breaker. Returns a negative, zero or positive value like a qsort()
   comparator. */
int cmp_cooFormat_csc(struct cooFormat *s, struct cooFormat *t)
{
    if (s->j == t->j) {
        return s->i - t->i;
    }
    return (s->j < t->j) ? -1 : 1;
}
/* FUNPTR is the comparator type qsort() expects; FUNPTR2 is the type of the
   typed COO comparators cmp_cooFormat_csr / cmp_cooFormat_csc. */
typedef int (*FUNPTR)(const void *, const void *);
typedef int (*FUNPTR2)(struct cooFormat *, struct cooFormat *);
/* Slot 0: row-major (CSR) ordering, slot 1: column-major (CSC) ordering.
   NOTE(review): calling these through a pointer cast to FUNPTR is undefined
   behaviour in C/C++ because the types differ; a qsort adapter with the exact
   FUNPTR signature is the portable way to use them. */
static FUNPTR2 fptr_array[] = {cmp_cooFormat_csr, cmp_cooFormat_csc};
/*
 * Sanity-checks a compressed sparse pattern: CSR as named, or CSC when called
 * with the roles of rows and columns swapped.
 *   m          number of rows (size of the compressed dimension)
 *   nnz        expected number of stored entries
 *   csrRowPtr  m + 1 offsets, all in the same index base
 *   csrColInd  nnz column indices
 * The checks are:
 *   - csrRowPtr spans exactly nnz entries;
 *   - the base (csrRowPtr[0]) is 0 or 1;
 *   - each row range is non-decreasing and lies inside [0, nnz);
 *   - within a row, column indices are >= base and strictly increasing
 *     (sorted, no duplicates).
 * Every problem found is reported on stderr, and the row scan stops after the
 * first row containing an error. Returns 1 if any error was found and 0 otherwise.
 */
static int verify_pattern(int m, int nnz, int *csrRowPtr, int *csrColInd)
{
    int error_found = 0;
    if (nnz != (csrRowPtr[m] - csrRowPtr[0])) {
        /* operands listed in the order of the check: csrRowPtr[m] - csrRowPtr[0] */
        fprintf(stderr,
                "Error (nnz check failed): (csrRowPtr[%d]=%d - csrRowPtr[%d]=%d) != (nnz=%d)\n",
                m,
                csrRowPtr[m],
                0,
                csrRowPtr[0],
                nnz);
        error_found = 1;
    }
    const int base_index = csrRowPtr[0];
    if ((0 != base_index) && (1 != base_index)) {
        fprintf(stderr, "Error (base index check failed): base index = %d\n", base_index);
        error_found = 1;
    }
    for (int i = 0; (!error_found) && (i < m); i++) {
        const int start = csrRowPtr[i] - base_index;
        const int end = csrRowPtr[i + 1] - base_index;
        if (start > end) {
            fprintf(stderr,
                    "Error (corrupted row): csrRowPtr[%d] (=%d) > csrRowPtr[%d] (=%d)\n",
                    i,
                    start + base_index,
                    i + 1,
                    end + base_index);
            error_found = 1;
            continue; /* an inverted range has no entries to scan */
        }
        /* Non-monotonic offsets can pass the nnz check above while a single row
           still points past the column array; never read csrColInd out of range. */
        if ((start < 0) || (end > nnz)) {
            fprintf(stderr,
                    "Error (row out of range): row %d spans [%d, %d), outside of [0, nnz=%d)\n",
                    i,
                    start,
                    end,
                    nnz);
            error_found = 1;
            continue;
        }
        for (int col = start; col < end; col++) {
            if (csrColInd[col] < base_index) {
                fprintf(stderr, "Error (column vs. base index check failed): csrColInd[%d] < %d\n", col, base_index);
                error_found = 1;
            }
            if ((col < (end - 1)) && (csrColInd[col] >= csrColInd[col + 1])) {
                fprintf(
                    stderr,
                    "Error (sorting of the column indecis check failed): (csrColInd[%d]=%d) >= (csrColInd[%d]=%d)\n",
                    col,
                    csrColInd[col],
                    col + 1,
                    csrColInd[col + 1]);
                error_found = 1;
            }
        }
    }
    return error_found;
}
template <typename T_ELEM>
int loadMMSparseMatrix(char *filename,
char elem_type,
bool csrFormat,
int *m,
int *n,
int *nnz,
T_ELEM **aVal,
int **aRowInd,
int **aColInd,
int extendSymMatrix)
{
MM_typecode matcode;
double *tempVal;
int *tempRowInd, *tempColInd;
double *tval;
int *trow, *tcol;
int *csrRowPtr, *cscColPtr;
int i, j, error, base, count;
struct cooFormat *work;
/* read the matrix */
error = mm_read_mtx_crd(filename, m, n, nnz, &trow, &tcol, &tval, &matcode);
if (error) {
fprintf(stderr, "!!!! can not open file: '%s'\n", filename);
return 1;
}
/* start error checking */
if (mm_is_complex(matcode) && ((elem_type != 'z') && (elem_type != 'c'))) {
fprintf(stderr, "!!!! complex matrix requires type 'z' or 'c'\n");
return 1;
}
if (mm_is_dense(matcode) || mm_is_array(matcode) || mm_is_pattern(matcode) /*|| mm_is_integer(matcode)*/) {
fprintf(stderr, "!!!! dense, array, pattern and integer matrices are not supported\n");
return 1;
}
/* if necessary symmetrize the pattern (transform from triangular to full) */
if ((extendSymMatrix) && (mm_is_symmetric(matcode) || mm_is_hermitian(matcode) || mm_is_skew(matcode))) {
// count number of non-diagonal elements
count = 0;
for (i = 0; i < (*nnz); i++) {
if (trow[i] != tcol[i]) {
count++;
}
}
// allocate space for the symmetrized matrix
tempRowInd = (int *)malloc((*nnz + count) * sizeof(int));
tempColInd = (int *)malloc((*nnz + count) * sizeof(int));
if (mm_is_real(matcode) || mm_is_integer(matcode)) {
tempVal = (double *)malloc((*nnz + count) * sizeof(double));
}
else {
tempVal = (double *)malloc(2 * (*nnz + count) * sizeof(double));
}
// copy the elements regular and transposed locations
for (j = 0, i = 0; i < (*nnz); i++) {
tempRowInd[j] = trow[i];
tempColInd[j] = tcol[i];
if (mm_is_real(matcode) || mm_is_integer(matcode)) {
tempVal[j] = tval[i];
}
else {
tempVal[2 * j] = tval[2 * i];
tempVal[2 * j + 1] = tval[2 * i + 1];
}
j++;
if (trow[i] != tcol[i]) {
tempRowInd[j] = tcol[i];
tempColInd[j] = trow[i];
if (mm_is_real(matcode) || mm_is_integer(matcode)) {
if (mm_is_skew(matcode)) {
tempVal[j] = -tval[i];
}
else {
tempVal[j] = tval[i];
}
}
else {
if (mm_is_hermitian(matcode)) {
tempVal[2 * j] = tval[2 * i];
tempVal[2 * j + 1] = -tval[2 * i + 1];
}
else {
tempVal[2 * j] = tval[2 * i];
tempVal[2 * j + 1] = tval[2 * i + 1];
}
}
j++;
}
}
(*nnz) += count;
// free temporary storage
free(trow);
free(tcol);
free(tval);
}
else {
tempRowInd = trow;
tempColInd = tcol;
tempVal = tval;
}
// life time of (trow, tcol, tval) is over.
// please use COO format (tempRowInd, tempColInd, tempVal)
// use qsort to sort COO format
work = (struct cooFormat *)malloc(sizeof(struct cooFormat) * (*nnz));
if (NULL == work) {
fprintf(stderr, "!!!! allocation error, malloc failed\n");
return 1;
}
for (i = 0; i < (*nnz); i++) {
work[i].i = tempRowInd[i];
work[i].j = tempColInd[i];
work[i].p = i; // permutation is identity
}
if (csrFormat) {
/* create row-major ordering of indices (sorted by row and within each row by column) */
qsort(work, *nnz, sizeof(struct cooFormat), (FUNPTR)fptr_array[0]);
}
else {
/* create column-major ordering of indices (sorted by column and within each column by row) */
qsort(work, *nnz, sizeof(struct cooFormat), (FUNPTR)fptr_array[1]);
}
// (tempRowInd, tempColInd) is sorted either by row-major or by col-major
for (i = 0; i < (*nnz); i++) {
tempRowInd[i] = work[i].i;
tempColInd[i] = work[i].j;
}
// setup base
// check if there is any row/col 0, if so base-0
// check if there is any row/col equal to matrix dimension m/n, if so base-1
int base0 = 0;
int base1 = 0;
for (i = 0; i < (*nnz); i++) {
const int row = tempRowInd[i];
const int col = tempColInd[i];
if ((0 == row) || (0 == col)) {
base0 = 1;
}
if ((*m == row) || (*n == col)) {
base1 = 1;
}
}
if (base0 && base1) {
printf("Error: input matrix is base-0 and base-1 \n");
return 1;
}
base = 0;
if (base1) {
base = 1;
}
/* compress the appropriate indices */
if (csrFormat) {
/* CSR format (assuming row-major format) */
csrRowPtr = (int *)malloc(((*m) + 1) * sizeof(csrRowPtr[0]));
if (!csrRowPtr)
return 1;
compress_index(tempRowInd, *nnz, *m, csrRowPtr, base);
*aRowInd = csrRowPtr;
*aColInd = (int *)malloc((*nnz) * sizeof(int));
}
else {
/* CSC format (assuming column-major format) */
cscColPtr = (int *)malloc(((*n) + 1) * sizeof(cscColPtr[0]));
if (!cscColPtr)
return 1;
compress_index(tempColInd, *nnz, *n, cscColPtr, base);
*aColInd = cscColPtr;
*aRowInd = (int *)malloc((*nnz) * sizeof(int));
}
/* transfrom the matrix values of type double into one of the cusparse library types */
*aVal = (T_ELEM *)malloc((*nnz) * sizeof(T_ELEM));
for (i = 0; i < (*nnz); i++) {
if (csrFormat) {
(*aColInd)[i] = tempColInd[i];
}
else {
(*aRowInd)[i] = tempRowInd[i];
}
if (mm_is_real(matcode) || mm_is_integer(matcode)) {
(*aVal)[i] = cuGet<T_ELEM>(tempVal[work[i].p]);
}
else {
(*aVal)[i] = cuGet<T_ELEM>(tempVal[2 * work[i].p], tempVal[2 * work[i].p + 1]);
}
}
/* check for corruption */
int error_found;
if (csrFormat) {
error_found = verify_pattern(*m, *nnz, *aRowInd, *aColInd);
}
else {
error_found = verify_pattern(*n, *nnz, *aColInd, *aRowInd);
}
if (error_found) {
fprintf(stderr, "!!!! verify_pattern failed\n");
return 1;
}
/* cleanup and exit */
free(work);
free(tempVal);
free(tempColInd);
free(tempRowInd);
return 0;
}
/*
 * Explicit instantiation definitions of loadMMSparseMatrix for the four
 * supported element types (single/double precision, real/complex). Code that
 * only sees a declaration of the template can link against these.
 */
template int loadMMSparseMatrix<float>(char *, char, bool, int *, int *, int *, float **, int **, int **, int);
template int loadMMSparseMatrix<double>(char *, char, bool, int *, int *, int *, double **, int **, int **, int);
template int
loadMMSparseMatrix<cuComplex>(char *, char, bool, int *, int *, int *, cuComplex **, int **, int **, int);
template int loadMMSparseMatrix<cuDoubleComplex>(
    char *, char, bool, int *, int *, int *, cuDoubleComplex **, int **, int **, int);