[cig-commits] r6981 - short/3D/PyLith/trunk/modulesrc/solver

knepley at geodynamics.org knepley at geodynamics.org
Mon May 28 14:46:14 PDT 2007


Author: knepley
Date: 2007-05-28 14:46:14 -0700 (Mon, 28 May 2007)
New Revision: 6981

Modified:
   short/3D/PyLith/trunk/modulesrc/solver/solver.pyxe.src
Log:
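Fixed solver module: the solver rhs and solution PETSc vectors are now held by the Solver object. They are created in initialize() from the mesh's global order and reused by solve(), instead of being duplicated on every solve. Local section vectors are now sized with sizeWithBC().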


Modified: short/3D/PyLith/trunk/modulesrc/solver/solver.pyxe.src
===================================================================
--- short/3D/PyLith/trunk/modulesrc/solver/solver.pyxe.src	2007-05-27 11:25:19 UTC (rev 6980)
+++ short/3D/PyLith/trunk/modulesrc/solver/solver.pyxe.src	2007-05-28 21:46:14 UTC (rev 6981)
@@ -66,6 +66,8 @@
   cdef readonly object handle # PyCObject holding pointer to C++ object
   cdef readonly object name # Identifier for object base type
   cdef void* vecScatterVptr # Handle to VecScatter
+  cdef void* vecInVptr # Handle to solver rhs
+  cdef void* vecOutVptr # Handle to solver solution
 
   def __init__(self):
     """
@@ -75,6 +77,8 @@
     self.thisptr = NULL
     self.name = "pylith_solver_Solver"
     self.vecScatterVptr = NULL
+    self.vecInVptr = NULL
+    self.vecOutVptr = NULL
 
     # create shim for constructor
     #embed{ void* KSP_create()
@@ -111,10 +115,12 @@
     Initialzie solver.
     """
     # create shim for method 'initialize'
-    #embed{ void Solver_initialize(void* objVptr, void** scatterVptr, void* meshVptr, void* fieldVptr)
+    #embed{ void Solver_initialize(void* objVptr, void** scatterVptr, void** vecInVptr, void** vecOutVptr, void* meshVptr, void* fieldVptr)
     try {
       assert(0 != objVptr);
       assert(0 != scatterVptr);
+      assert(0 != vecInVptr);
+      assert(0 != vecOutVptr);
       assert(0 != meshVptr);
       assert(0 != fieldVptr);
       ALE::Obj<ALE::Mesh>* mesh =
@@ -135,6 +141,38 @@
         } // if
       } // if
       *scatterVptr = (void *) scatter;
+      Vec in;
+      const ALE::Obj<ALE::Mesh::order_type>& order = (*mesh)->getFactory()->getGlobalOrder((*mesh), "default", (*field));
+
+      err = VecCreate((*mesh)->comm(), &in);
+      err = VecSetSizes(in, order->getLocalSize(), order->getGlobalSize());
+      err = VecSetFromOptions(in);
+      if (err) {
+        PetscError(__LINE__,__FUNCT__,__FILE__,__SDIR__,err,0," ");
+        throw std::runtime_error("Could not create vector.");
+      } // if
+      if (0 != *vecInVptr) {
+        err = VecDestroy((Vec) *vecInVptr);
+        if (err) {
+          PetscError(__LINE__,__FUNCT__,__FILE__,__SDIR__,err,0," ");
+          throw std::runtime_error("Could not destroy vector.");
+        } // if
+      } // if
+      *vecInVptr = (void *) in;
+      Vec out;
+      err = VecDuplicate(in, &out);
+      if (err) {
+        PetscError(__LINE__,__FUNCT__,__FILE__,__SDIR__,err,0," ");
+        throw std::runtime_error("Could not create vector.");
+      } // if
+      if (0 != *vecOutVptr) {
+        err = VecDestroy((Vec) *vecOutVptr);
+        if (err) {
+          PetscError(__LINE__,__FUNCT__,__FILE__,__SDIR__,err,0," ");
+          throw std::runtime_error("Could not destroy vector.");
+        } // if
+      } // if
+      *vecOutVptr = (void *) out;
     } catch (const std::exception& err) {
       PyErr_SetString(PyExc_RuntimeError,
                       const_cast<char*>(err.what()));
@@ -151,6 +189,7 @@
     cdef void* fieldVptr
     fieldVptr = PyCObject_AsVoidPtr(field)
     Solver_initialize(self.thisptr, &self.vecScatterVptr,
+                      &self.vecInVptr, &self.vecOutVptr,
                       ptrFromHandle(mesh), fieldVptr)
     return
 
@@ -178,7 +217,7 @@
     Solve linear system.
     """
     # create shim for method 'solve'
-    #embed{ int SolverLinear_solve(void* objVptr, void* fieldOutVptr, void* jacobianVptr, void* fieldInVptr, void* scatterVptr)
+    #embed{ int SolverLinear_solve(void* objVptr, void* fieldOutVptr, void* jacobianVptr, void* fieldInVptr, void* scatterVptr, void* vecInVptr, void* vecOutVptr)
     typedef ALE::Mesh::real_section_type real_section_type;
     try {
       assert(0 != objVptr);
@@ -186,6 +225,8 @@
       assert(0 != jacobianVptr);
       assert(0 != fieldInVptr);
       assert(0 != scatterVptr);
+      assert(0 != vecInVptr);
+      assert(0 != vecOutVptr);
 
       KSP* ksp = (KSP*) objVptr;
       ALE::Obj<real_section_type>* fieldOut =
@@ -194,21 +235,14 @@
       ALE::Obj<real_section_type>* fieldIn =
         (ALE::Obj<real_section_type>*) fieldInVptr;
       VecScatter scatter = (VecScatter) scatterVptr;
+      Vec vecIn = (Vec) vecInVptr;
+      Vec vecOut = (Vec) vecOutVptr;
 
       PetscErrorCode err = 0;
       Vec localVec;
-
-      /** :QUESTION:
-       *
-       * Declare these here or hold them in the Solver object so that
-       * we can reuse them?
-       */
-      Vec vecIn;
-      Vec vecOut;
       
-      err = VecCreateSeqWithArray(PETSC_COMM_SELF, (*fieldIn)->size(),
+      err = VecCreateSeqWithArray(PETSC_COMM_SELF, (*fieldIn)->sizeWithBC(),
                                   (*fieldIn)->restrict(), &localVec);CHKERRQ(err);
-      err = VecDuplicate(localVec, &vecIn); CHKERRQ(err);
       err = VecScatterBegin(scatter, localVec, vecIn, INSERT_VALUES, SCATTER_FORWARD
                             );CHKERRQ(err);
       err = VecScatterEnd(scatter, localVec, vecIn, INSERT_VALUES, SCATTER_FORWARD
@@ -216,9 +250,8 @@
       err = VecDestroy(localVec); CHKERRQ(err);
       err = KSPSetOperators(*ksp, *jacobian, *jacobian,
                             DIFFERENT_NONZERO_PATTERN); CHKERRQ(err);
-      err = VecDuplicate(vecIn, &vecOut); CHKERRQ(err);
       err = KSPSolve(*ksp, vecIn, vecOut); CHKERRQ(err);
-      err = VecCreateSeqWithArray(PETSC_COMM_SELF, (*fieldOut)->size(),
+      err = VecCreateSeqWithArray(PETSC_COMM_SELF, (*fieldOut)->sizeWithBC(),
                                 (*fieldOut)->restrict(), &localVec);CHKERRQ(err);
       err = VecScatterBegin(scatter, vecOut, localVec, INSERT_VALUES, SCATTER_REVERSE
                             ); CHKERRQ(err);
@@ -241,7 +274,8 @@
     jacobianVptr = PyCObject_AsVoidPtr(jacobian)
     fieldInVptr = PyCObject_AsVoidPtr(fieldIn)
     SolverLinear_solve(self.thisptr, fieldOutVptr, jacobianVptr,
-                       fieldInVptr, self.vecScatterVptr)
+                       fieldInVptr, self.vecScatterVptr,
+                       self.vecInVptr, self.vecOutVptr)
     return
 
 



More information about the cig-commits mailing list