[cig-commits] r12370 - in short/3D/PyLith/trunk: libsrc/topology modulesrc/solver

knepley at geodynamics.org knepley at geodynamics.org
Thu Jul 3 10:29:56 PDT 2008


Author: knepley
Date: 2008-07-03 10:29:56 -0700 (Thu, 03 Jul 2008)
New Revision: 12370

Modified:
   short/3D/PyLith/trunk/libsrc/topology/Distributor.cc
   short/3D/PyLith/trunk/modulesrc/solver/solver.pyxe.src
Log:
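Added initial nonlinear solver support (SolverNonlinear built on PETSc SNES) and a consistency check that send and receive overlap sizes match during mesh distribution.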


Modified: short/3D/PyLith/trunk/libsrc/topology/Distributor.cc
===================================================================
--- short/3D/PyLith/trunk/libsrc/topology/Distributor.cc	2008-07-02 21:45:34 UTC (rev 12369)
+++ short/3D/PyLith/trunk/libsrc/topology/Distributor.cc	2008-07-03 17:29:56 UTC (rev 12370)
@@ -67,6 +67,22 @@
       std::cout << "["<<origMesh->commRank()<<"]:   global point " << r_iter->first << " --> " << " local point " << r_iter->second << std::endl;
     }
   }
+  // Check overlap
+  int localSendOverlapSize = 0, sendOverlapSize;
+  int localRecvOverlapSize = 0, recvOverlapSize;
+  for(int p = 0; p < sendMeshOverlap->commSize(); ++p) {
+    localSendOverlapSize += sendMeshOverlap->cone(p)->size();
+    localRecvOverlapSize += recvMeshOverlap->support(p)->size();
+  }
+  MPI_Allreduce(&localSendOverlapSize, &sendOverlapSize, 1, MPI_INT, MPI_SUM, sendMeshOverlap->comm());
+  MPI_Allreduce(&localRecvOverlapSize, &recvOverlapSize, 1, MPI_INT, MPI_SUM, recvMeshOverlap->comm());
+  if(sendOverlapSize != recvOverlapSize) {
+    std::cout <<"["<<sendMeshOverlap->commRank()<<"]: Size mismatch " << sendOverlapSize << " != " << recvOverlapSize << std::endl;
+    sendMeshOverlap->view("Send Overlap");
+    recvMeshOverlap->view("Recv Overlap");
+    throw ALE::Exception("Invalid Overlap");
+  }
+
   // Distribute the coordinates
   const Obj<real_section_type>& coordinates         = origMesh->getRealSection("coordinates");
   const Obj<real_section_type>& parallelCoordinates = (*newMesh)->getRealSection("coordinates");

Modified: short/3D/PyLith/trunk/modulesrc/solver/solver.pyxe.src
===================================================================
--- short/3D/PyLith/trunk/modulesrc/solver/solver.pyxe.src	2008-07-02 21:45:34 UTC (rev 12369)
+++ short/3D/PyLith/trunk/modulesrc/solver/solver.pyxe.src	2008-07-03 17:29:56 UTC (rev 12370)
@@ -13,12 +13,20 @@
 #header{
 #include <petscmesh.hh>
 #include <petscksp.h>
+#include <petscsnes.h>
 
 #include "pylith/utils/sievetypes.hh" // USES PETSc Mesh
 
 #include <assert.h>
 #include <stdexcept>
 #include <Python.h>
+
+// Context passed to the SNES residual callback PyLith_IntegrateResidual
+// (registered through SNESSetFunction in SolverNonlinear.initialize()).
+// NOTE(review): fieldIn/fieldOut are declared as references, but the struct
+// is obtained from PetscMalloc in SolverStruct_create(), so no constructor
+// ever binds them; "s->fieldIn = ..." in SolverNonlinear_solve therefore
+// assigns through an unbound reference. Value members (constructed with
+// placement new / destroyed explicitly) or pointers are probably intended --
+// TODO confirm and change together with the create/destroy/solve shims.
+typedef struct {
+  ALE::Obj<pylith::Mesh::real_section_type> &fieldIn;  // local section that receives the SNES input vector
+  ALE::Obj<pylith::Mesh::real_section_type> &fieldOut; // local section scattered into the SNES output vector
+  VecScatter scatter;                                   // scatter between section storage and the PETSc Vec
+} PylithSolverStruct;
+
 #}header
 
 # ----------------------------------------------------------------------
@@ -63,7 +71,92 @@
   KSP_destructor_cpp(obj)
   return
 
+cdef void SNES_destructor(void* obj):
+  """
+  Destroy SNES object.
+
+  PyCObject destructor for the handle built by SolverNonlinear._createHandle();
+  obj is the heap-allocated SNES* returned by SNES_create() (NULL if creation
+  failed). Failures are reported by setting a Python RuntimeError.
+  """
+  # create shim for destructor
+  #embed{ void SNES_destructor_cpp(void* objVptr)
+  try {
+    SNES* snes = (SNES*) objVptr;
+    // SNES_create() returns NULL on failure and that pointer is still
+    // wrapped in a PyCObject, so tolerate it here.
+    if (0 == snes)
+      return;
+    PetscErrorCode err = SNESDestroy(*snes);
+    // Always release the holder allocated with 'new', even when
+    // SNESDestroy reports an error, so the error path does not leak it.
+    delete snes;
+    if (err) {
+      PetscError(__LINE__,__FUNCT__,__FILE__,__SDIR__,err,0," ");
+      throw std::runtime_error("Could not destroy SNES object.");
+    } // if
+  } catch (const std::exception& err) {
+    PyErr_SetString(PyExc_RuntimeError,
+                    const_cast<char*>(err.what()));
+  } catch (const ALE::Exception& err) {
+    PyErr_SetString(PyExc_RuntimeError,
+                    const_cast<char*>(err.msg().c_str()));
+  } catch (...) {
+    PyErr_SetString(PyExc_RuntimeError,
+                    "Caught unknown C++ exception.");
+  } // try/catch
+  #}embed
+  SNES_destructor_cpp(obj)
+  return
 
+cdef void SolverStruct_destructor(void* obj):
+  """
+  Release the PylithSolverStruct owned by a PyCObject.
+
+  Frees the block that SolverStruct_create() obtained from PetscMalloc.
+  A PETSc failure is converted into a Python RuntimeError.
+  """
+  # C++ shim that performs the release
+  #embed{ void SolverStruct_destructor_cpp(void* objVptr)
+  try {
+    PylithSolverStruct* ctx = static_cast<PylithSolverStruct*>(objVptr);
+    const PetscErrorCode ierr = PetscFree(ctx);
+    if (0 != ierr) {
+      PetscError(__LINE__,__FUNCT__,__FILE__,__SDIR__,ierr,0," ");
+      throw std::runtime_error("Could not destroy solver structure.");
+    } // if
+  } catch (const std::exception& err) {
+    PyErr_SetString(PyExc_RuntimeError, const_cast<char*>(err.what()));
+  } catch (const ALE::Exception& err) {
+    PyErr_SetString(PyExc_RuntimeError, const_cast<char*>(err.msg().c_str()));
+  } catch (...) {
+    PyErr_SetString(PyExc_RuntimeError, "Caught unknown C++ exception.");
+  } // try/catch
+  #}embed
+  SolverStruct_destructor_cpp(obj)
+  return
+
+#header{
+// SNES function callback registered by SolverNonlinear.initialize() through
+// SNESSetFunction(); ctx is the PylithSolverStruct passed there.
+// At present it only moves data: vecIn is scattered into the storage of
+// s->fieldIn, and the storage of s->fieldOut is scattered into vecOut. The
+// residual integration itself is still a TODO, so vecOut does not yet hold a
+// real residual.
+PetscErrorCode PyLith_IntegrateResidual(SNES snes, Vec vecIn, Vec vecOut, void *ctx) {
+  // Get fields, scatter, and function from ctx
+  PylithSolverStruct *s = (PylithSolverStruct *) ctx;
+  Vec                 localVec;
+  PetscErrorCode      ierr;
+
+  PetscFunctionBegin;
+  // TODO: Evaluate material properties for new solution guess (needed for line search)
+
+  // localVec wraps the section's own storage (restrictSpace()) without
+  // copying, so the reverse scatter writes straight into s->fieldIn.
+  ierr = VecCreateSeqWithArray(PETSC_COMM_SELF, s->fieldIn->sizeWithBC(), s->fieldIn->restrictSpace(), &localVec);CHKERRQ(ierr);
+  ierr = VecScatterBegin(s->scatter, vecIn, localVec, INSERT_VALUES, SCATTER_REVERSE); CHKERRQ(ierr);
+  ierr = VecScatterEnd(s->scatter, vecIn, localVec, INSERT_VALUES, SCATTER_REVERSE);CHKERRQ(ierr);
+  ierr = VecDestroy(localVec);CHKERRQ(ierr);
+
+  // TODO: pylith::integrateResidual(s->fieldIn, s->fieldOut);
+  //  residual = self.fields.getReal("residual")
+  //  residual->zero();
+  //  for integrator in self.integrators:
+  //    integrator.timeStep(dt)
+  //    integrator.integrateResidual(residual, t, self.fields)
+  //  bindings.completeSection(self.mesh.cppHandle, residual)
+
+  // Same wrapping trick for the output section, scattered forward into vecOut.
+  ierr = VecCreateSeqWithArray(PETSC_COMM_SELF, s->fieldOut->sizeWithBC(), s->fieldOut->restrictSpace(), &localVec);CHKERRQ(ierr);
+  ierr = VecScatterBegin(s->scatter, localVec, vecOut, INSERT_VALUES, SCATTER_FORWARD);CHKERRQ(ierr);
+  ierr = VecScatterEnd(s->scatter, localVec, vecOut, INSERT_VALUES, SCATTER_FORWARD);CHKERRQ(ierr);
+  ierr = VecDestroy(localVec);CHKERRQ(ierr);
+  PetscFunctionReturn(0);
+}
+#}header
+
 # ----------------------------------------------------------------------
 cdef class Solver:
 
@@ -84,43 +177,12 @@
     self.vecScatterVptr = NULL
     self.vecInVptr = NULL
     self.vecOutVptr = NULL
-
-    # create shim for constructor
-    #embed{ void* KSP_create()
-    void* result = 0;
-    try {
-      KSP* ksp = new KSP;
-      PetscErrorCode err = KSPCreate(PETSC_COMM_WORLD, ksp);
-      if (err) {
-        PetscError(__LINE__,__FUNCT__,__FILE__,__SDIR__,err,0," ");
-        throw std::runtime_error("Could not create KSP object.");
-      } // if
-      err = KSPSetFromOptions(*ksp);
-      if (err) {
-        PetscError(__LINE__,__FUNCT__,__FILE__,__SDIR__,err,0," ");
-        throw std::runtime_error("Could not set KSP options.");
-      } // if
-      result = (void*) ksp;
-    } catch (const std::exception& err) {
-      PyErr_SetString(PyExc_RuntimeError,
-                      const_cast<char*>(err.what()));
-    } catch (const ALE::Exception& err) {
-      PyErr_SetString(PyExc_RuntimeError,
-                      const_cast<char*>(err.msg().c_str()));
-    } catch (...) {
-      PyErr_SetString(PyExc_RuntimeError,
-                      "Caught unknown C++ exception.");
-    } // try/catch
-    return result;
-    #}embed
-    self.thisptr = KSP_create()
-    self.handle = self._createHandle()
     return
 
 
   def initialize(self, mesh, field):
     """
-    Initialzie solver.
+    Initialize solver.
     """
     # create shim for method 'initialize'
     #embed{ void Solver_initialize(void* objVptr, void** scatterVptr, void** vecInVptr, void** vecOutVptr, void* meshVptr, void* fieldVptr)
@@ -131,8 +193,7 @@
       assert(0 != vecOutVptr);
       assert(0 != meshVptr);
       assert(0 != fieldVptr);
-      ALE::Obj<pylith::Mesh>* mesh =
-        (ALE::Obj<pylith::Mesh>*) meshVptr;
+      ALE::Obj<pylith::Mesh>* mesh = (ALE::Obj<pylith::Mesh>*) meshVptr;
       ALE::Obj<pylith::Mesh::real_section_type>* field =
         (ALE::Obj<pylith::Mesh::real_section_type>*) fieldVptr;
       VecScatter scatter;
@@ -211,6 +272,53 @@
     Set initial guess nonzero flag.
     (true if nonzero initial guess, false if initial guess should be zero).
     """
+    return
+
+
+# ----------------------------------------------------------------------
+cdef class SolverLinear(Solver):
+
+  def __init__(self):
+    """
+    Constructor.
+
+    Creates a PETSc KSP object on PETSC_COMM_WORLD, applies settings from
+    the PETSc options database (KSPSetFromOptions), and wraps the pointer in
+    a PyCObject stored in self.handle.
+    """
+    Solver.__init__(self)
+
+    # create shim for constructor
+    #embed{ void* KSP_create()
+    void* result = 0;
+    try {
+      KSP* ksp = new KSP;
+      PetscErrorCode err = KSPCreate(PETSC_COMM_WORLD, ksp);
+      if (err) {
+        PetscError(__LINE__,__FUNCT__,__FILE__,__SDIR__,err,0," ");
+        delete ksp; // no KSP was created; release only the holder
+        throw std::runtime_error("Could not create KSP object.");
+      } // if
+      err = KSPSetFromOptions(*ksp);
+      if (err) {
+        PetscError(__LINE__,__FUNCT__,__FILE__,__SDIR__,err,0," ");
+        // Do not leak the created-but-unconfigured KSP or its holder.
+        KSPDestroy(*ksp);
+        delete ksp;
+        throw std::runtime_error("Could not set KSP options.");
+      } // if
+      result = (void*) ksp;
+    } catch (const std::exception& err) {
+      PyErr_SetString(PyExc_RuntimeError, const_cast<char*>(err.what()));
+    } catch (const ALE::Exception& err) {
+      PyErr_SetString(PyExc_RuntimeError, const_cast<char*>(err.msg().c_str()));
+    } catch (...) {
+      PyErr_SetString(PyExc_RuntimeError, "Caught unknown C++ exception.");
+    } // try/catch
+    return result;
+    #}embed
+    self.thisptr = KSP_create()
+    self.handle  = self._createHandle()
+    return
+
+
+  def setInitialGuessNonzero(self, value):
+    """
+    Set initial guess nonzero flag.
+    (true if nonzero initial guess, false if initial guess should be zero).
+    """
     # create shim for method 'setInitialGuessNonzero'
     #embed{ void Solver_setInitialGuessNonzero(void* objVptr, int value)
     try {
@@ -219,40 +327,17 @@
       PetscTruth flag = (value) ? PETSC_TRUE : PETSC_FALSE;
       KSPSetInitialGuessNonzero(*ksp, flag);
     } catch (const std::exception& err) {
-      PyErr_SetString(PyExc_RuntimeError,
-                      const_cast<char*>(err.what()));
+      PyErr_SetString(PyExc_RuntimeError, const_cast<char*>(err.what()));
     } catch (const ALE::Exception& err) {
-      PyErr_SetString(PyExc_RuntimeError,
-                      const_cast<char*>(err.msg().c_str()));
+      PyErr_SetString(PyExc_RuntimeError, const_cast<char*>(err.msg().c_str()));
     } catch (...) {
-      PyErr_SetString(PyExc_RuntimeError,
-                      "Caught unknown C++ exception.");
+      PyErr_SetString(PyExc_RuntimeError, "Caught unknown C++ exception.");
     } // try/catch      
     #}embed
 
     Solver_setInitialGuessNonzero(self.thisptr, value)
     return
-    
 
-
-  def _createHandle(self):
-    """
-    Wrap pointer to C++ object in PyCObject.
-    """
-    return PyCObject_FromVoidPtr(self.thisptr, KSP_destructor)
-
-
-# ----------------------------------------------------------------------
-cdef class SolverLinear(Solver):
-
-  def __init__(self):
-    """
-    Constructor.
-    """
-    Solver.__init__(self)
-    return
-
-
   def solve(self, fieldOut, jacobian, fieldIn):
     """
     Solve linear system.
@@ -323,24 +408,83 @@
                        self.vecInVptr, self.vecOutVptr)
     return
 
+  def _createHandle(self):
+    """
+    Package the KSP* held in self.thisptr as a PyCObject whose destructor
+    is KSP_destructor.
+    """
+    handle = PyCObject_FromVoidPtr(self.thisptr, KSP_destructor)
+    return handle
 
+
 # ----------------------------------------------------------------------
 cdef class SolverNonlinear(Solver):
+  cdef void* solverStruct # Pointer to C++ object
+  cdef readonly object solverStructHandle # PyCObject holding pointer to C++ object
 
   def __init__(self):
     """
     Constructor.
     """
     Solver.__init__(self)
+
+    # Create the PETSc SNES object (settings taken from the PETSc options
+    # database), wrap it in self.handle, and allocate the context struct that
+    # initialize() hands to the SNES residual callback.
+    # create shim for constructor
+    #embed{ void* SNES_create()
+    void* result = 0;
+    try {
+      SNES* snes = new SNES;
+      PetscErrorCode err = SNESCreate(PETSC_COMM_WORLD, snes);
+      if (err) {
+        PetscError(__LINE__,__FUNCT__,__FILE__,__SDIR__,err,0," ");
+        delete snes; // no SNES was created; release only the holder
+        throw std::runtime_error("Could not create SNES object.");
+      } // if
+      err = SNESSetFromOptions(*snes);
+      if (err) {
+        PetscError(__LINE__,__FUNCT__,__FILE__,__SDIR__,err,0," ");
+        // Do not leak the created-but-unconfigured SNES or its holder.
+        SNESDestroy(*snes);
+        delete snes;
+        throw std::runtime_error("Could not set SNES options.");
+      } // if
+      result = (void*) snes;
+    } catch (const std::exception& err) {
+      PyErr_SetString(PyExc_RuntimeError, const_cast<char*>(err.what()));
+    } catch (const ALE::Exception& err) {
+      PyErr_SetString(PyExc_RuntimeError, const_cast<char*>(err.msg().c_str()));
+    } catch (...) {
+      PyErr_SetString(PyExc_RuntimeError, "Caught unknown C++ exception.");
+    } // try/catch
+    return result;
+    #}embed
+    self.thisptr = SNES_create()
+    self.handle  = self._createHandle()
+    self.solverStructHandle = self._createStruct()
     return
 
 
+  def initialize(self, mesh, field):
+    """
+    Initialize solver.
+
+    Runs the base-class setup (scatter and PETSc vectors), then registers
+    PyLith_IntegrateResidual as the SNES function, with self.vecOutVptr as
+    the function vector and self.solverStruct as the callback context.
+    Errors are reported as a Python RuntimeError.
+    """
+    Solver.initialize(self, mesh, field)
+    # create shim for method 'initialize'
+    # The shim returns nothing (it was declared void* with no return, which is
+    # undefined behavior) and must not let C++ exceptions escape into the
+    # Pyrex-generated C code, so errors are translated into Python errors.
+    #embed{ void SNES_setFunction(void* objVptr, void* vecOutVptr, void* solverStructVptr)
+    try {
+      assert(0 != objVptr);
+      assert(0 != vecOutVptr);
+      assert(0 != solverStructVptr);
+
+      SNES*          snes   = (SNES*) objVptr;
+      Vec            vecOut = (Vec) vecOutVptr;
+      // NOTE(review): SolverNonlinear.solve() also passes this vecOut to
+      // SNESSolve() as the solution vector, so residual storage aliases the
+      // solution; a separate residual Vec is probably needed -- confirm.
+      PetscErrorCode err = SNESSetFunction(*snes, vecOut, PyLith_IntegrateResidual, solverStructVptr);
+      if (err) {
+        PetscError(__LINE__,__FUNCT__,__FILE__,__SDIR__,err,0," ");
+        throw std::runtime_error("Could not set SNES function.");
+      } // if
+    } catch (const std::exception& err) {
+      PyErr_SetString(PyExc_RuntimeError, const_cast<char*>(err.what()));
+    } catch (const ALE::Exception& err) {
+      PyErr_SetString(PyExc_RuntimeError, const_cast<char*>(err.msg().c_str()));
+    } catch (...) {
+      PyErr_SetString(PyExc_RuntimeError, "Caught unknown C++ exception.");
+    } // try/catch
+    #}embed
+    SNES_setFunction(self.thisptr, self.vecOutVptr, self.solverStruct)
+    return
+
   def solve(self, fieldOut, fieldIn):
     """
     Solve nonlinear system.
     """
     # create shim for method 'solve'
-    #embed{ int SolverNonlinear_solve(void* objVptr, void* fieldOutVptr, void* fieldInVptr, void* scatterVptr, void* vecInVptr, void* vecOutVptr)
+    #embed{ int SolverNonlinear_solve(void* objVptr, void* fieldOutVptr, void* fieldInVptr, void* scatterVptr, void* vecInVptr, void* vecOutVptr, void* solverStructVptr)
     typedef pylith::Mesh::real_section_type real_section_type;
     PetscErrorCode err = 0;
     try {
@@ -350,35 +494,29 @@
       assert(0 != scatterVptr);
       assert(0 != vecInVptr);
       assert(0 != vecOutVptr);
+      assert(0 != solverStructVptr);
 
-      KSP* ksp = (KSP*) objVptr;
-      ALE::Obj<real_section_type>* fieldOut =
-        (ALE::Obj<real_section_type>*) fieldOutVptr;
-      ALE::Obj<real_section_type>* fieldIn =
-        (ALE::Obj<real_section_type>*) fieldInVptr;
-      VecScatter scatter = (VecScatter) scatterVptr;
-      Vec vecIn = (Vec) vecInVptr;
-      Vec vecOut = (Vec) vecOutVptr;
+      SNES* snes = (SNES*) objVptr;
+      ALE::Obj<real_section_type>* fieldOut = (ALE::Obj<real_section_type>*) fieldOutVptr;
+      ALE::Obj<real_section_type>* fieldIn  = (ALE::Obj<real_section_type>*) fieldInVptr;
+      VecScatter                   scatter  = (VecScatter) scatterVptr;
+      Vec                          vecIn    = (Vec) vecInVptr;
+      Vec                          vecOut   = (Vec) vecOutVptr;
+      PylithSolverStruct*          s        = (PylithSolverStruct *) solverStructVptr;
+      Vec                          localVec;
 
-      Vec localVec;
-
-      err = VecCreateSeqWithArray(PETSC_COMM_SELF, (*fieldIn)->sizeWithBC(),
-                                  (*fieldIn)->restrictSpace(), &localVec);CHKERRQ(err);
-      err = VecScatterBegin(scatter, localVec, vecIn, INSERT_VALUES, SCATTER_FORWARD
-                            );CHKERRQ(err);
-      err = VecScatterEnd(scatter, localVec, vecIn, INSERT_VALUES, SCATTER_FORWARD
-                          ); CHKERRQ(err);
-      err = VecDestroy(localVec); CHKERRQ(err);
-
-      err = KSPSolve(*ksp, vecIn, vecOut); CHKERRQ(err);
-
-      err = VecCreateSeqWithArray(PETSC_COMM_SELF, (*fieldOut)->sizeWithBC(),
-                                (*fieldOut)->restrictSpace(), &localVec);CHKERRQ(err);
-      err = VecScatterBegin(scatter, vecOut, localVec, INSERT_VALUES, SCATTER_REVERSE
-                            ); CHKERRQ(err);
-      err = VecScatterEnd(scatter, vecOut, localVec, INSERT_VALUES, SCATTER_REVERSE
-                          ); CHKERRQ(err);
-      err = VecDestroy(localVec); CHKERRQ(err);
+      s->fieldIn  = *fieldIn;
+      s->fieldOut = *fieldOut;
+      s->scatter  = scatter;
+      err = VecCreateSeqWithArray(PETSC_COMM_SELF, (*fieldIn)->sizeWithBC(), (*fieldIn)->restrictSpace(), &localVec);CHKERRQ(err);
+      err = VecScatterBegin(scatter, localVec, vecIn, INSERT_VALUES, SCATTER_FORWARD);CHKERRQ(err);
+      err = VecScatterEnd(scatter, localVec, vecIn, INSERT_VALUES, SCATTER_FORWARD);CHKERRQ(err);
+      err = VecDestroy(localVec);CHKERRQ(err);
+      err = SNESSolve(*snes, vecIn, vecOut); CHKERRQ(err);
+      err = VecCreateSeqWithArray(PETSC_COMM_SELF, (*fieldOut)->sizeWithBC(), (*fieldOut)->restrictSpace(), &localVec);CHKERRQ(err);
+      err = VecScatterBegin(scatter, vecOut, localVec, INSERT_VALUES, SCATTER_REVERSE); CHKERRQ(err);
+      err = VecScatterEnd(scatter, vecOut, localVec, INSERT_VALUES, SCATTER_REVERSE);CHKERRQ(err);
+      err = VecDestroy(localVec);CHKERRQ(err);
     } catch (const std::exception& err) {
       PyErr_SetString(PyExc_RuntimeError,
                       const_cast<char*>(err.what()));
@@ -395,11 +533,40 @@
     cdef void* fieldOutVptr
     cdef void* fieldInVptr
     fieldOutVptr = PyCObject_AsVoidPtr(fieldOut)
-    fieldInVptr = PyCObject_AsVoidPtr(fieldIn)
-    SolverNonlinear_solve(self.thisptr, fieldOutVptr,
-                          fieldInVptr, self.vecScatterVptr,
-                          self.vecInVptr, self.vecOutVptr)
+    fieldInVptr  = PyCObject_AsVoidPtr(fieldIn)
+    SolverNonlinear_solve(self.thisptr, fieldOutVptr, fieldInVptr, self.vecScatterVptr, self.vecInVptr, self.vecOutVptr, self.solverStruct)
     return
 
+  def _createHandle(self):
+    """
+    Package the SNES* held in self.thisptr as a PyCObject whose destructor
+    is SNES_destructor.
+    """
+    handle = PyCObject_FromVoidPtr(self.thisptr, SNES_destructor)
+    return handle
 
+  def _createStruct(self):
+    """
+    Allocate the PylithSolverStruct used as the SNES residual-callback
+    context, store the raw pointer in self.solverStruct, and return a
+    PyCObject that frees it via SolverStruct_destructor.
+    """
+    # create shim for constructor
+    #embed{ void* SolverStruct_create()
+    void* result = 0;
+    try {
+      PetscErrorCode err = PetscMalloc(sizeof(PylithSolverStruct), &result);
+      if (err) {
+        PetscError(__LINE__,__FUNCT__,__FILE__,__SDIR__,err,0," ");
+        throw std::runtime_error("Could not create solver structure.");
+      } // if
+      // PetscMalloc leaves the memory uninitialized; zero it so no member
+      // (e.g. the scatter) holds garbage before solve() fills the struct in.
+      err = PetscMemzero(result, sizeof(PylithSolverStruct));
+      if (err) {
+        PetscError(__LINE__,__FUNCT__,__FILE__,__SDIR__,err,0," ");
+        throw std::runtime_error("Could not initialize solver structure.");
+      } // if
+    } catch (const std::exception& err) {
+      PyErr_SetString(PyExc_RuntimeError, const_cast<char*>(err.what()));
+    } catch (const ALE::Exception& err) {
+      PyErr_SetString(PyExc_RuntimeError, const_cast<char*>(err.msg().c_str()));
+    } catch (...) {
+      PyErr_SetString(PyExc_RuntimeError, "Caught unknown C++ exception.");
+    } // try/catch
+    return result;
+    #}embed
+    self.solverStruct = SolverStruct_create()
+    return PyCObject_FromVoidPtr(self.solverStruct, SolverStruct_destructor)
+
+
 # End of file 



More information about the cig-commits mailing list