diff --git a/src/pl/plpython/expected/plpython_setof.out b/src/pl/plpython/expected/plpython_setof.out index 62b8a454a3..308d2abb7f 100644 --- a/src/pl/plpython/expected/plpython_setof.out +++ b/src/pl/plpython/expected/plpython_setof.out @@ -124,6 +124,35 @@ SELECT test_setof_spi_in_iterator(); World (4 rows) +-- set-returning function that modifies its parameters +CREATE OR REPLACE FUNCTION ugly(x int, lim int) RETURNS SETOF int AS $$ +global x +while x <= lim: + yield x + x = x + 1 +$$ LANGUAGE plpythonu; +SELECT ugly(1, 5); + ugly +------ + 1 + 2 + 3 + 4 + 5 +(5 rows) + +-- interleaved execution of such a function +SELECT ugly(1,3), ugly(7,8); + ugly | ugly +------+------ + 1 | 7 + 2 | 8 + 3 | 7 + 1 | 8 + 2 | 7 + 3 | 8 +(6 rows) + -- returns set of named-composite-type tuples CREATE OR REPLACE FUNCTION get_user_records() RETURNS SETOF users diff --git a/src/pl/plpython/expected/plpython_spi.out b/src/pl/plpython/expected/plpython_spi.out index e715ee5393..dbde36f841 100644 --- a/src/pl/plpython/expected/plpython_spi.out +++ b/src/pl/plpython/expected/plpython_spi.out @@ -57,6 +57,15 @@ for r in rv: return seq ' LANGUAGE plpythonu; +CREATE FUNCTION spi_recursive_sum(a int) RETURNS int + AS +'r = 0 +if a > 1: + r = plpy.execute("SELECT spi_recursive_sum(%d) as a" % (a-1))[0]["a"] +return a + r +' + LANGUAGE plpythonu; +-- -- spi and nested calls -- select nested_call_one('pass this along'); @@ -112,6 +121,12 @@ SELECT join_sequences(sequences) FROM sequences ---------------- (0 rows) +SELECT spi_recursive_sum(10); + spi_recursive_sum +------------------- + 55 +(1 row) + -- -- plan and result objects -- diff --git a/src/pl/plpython/plpy_exec.c b/src/pl/plpython/plpy_exec.c index 24aed011e4..25e4744c7d 100644 --- a/src/pl/plpython/plpy_exec.c +++ b/src/pl/plpython/plpy_exec.c @@ -26,8 +26,21 @@ #include "plpy_subxactobject.h" +/* saved state for a set-returning function */ +typedef struct PLySRFState +{ + PyObject *iter; /* Python iterator producing results */ + PLySavedArgs *savedargs; /* function argument values */ + MemoryContextCallback callback; /* for releasing refcounts when done */ +} PLySRFState; + static PyObject *PLy_function_build_args(FunctionCallInfo fcinfo, PLyProcedure *proc); -static void PLy_function_delete_args(PLyProcedure *proc); +static PLySavedArgs *PLy_function_save_args(PLyProcedure *proc); +static void PLy_function_restore_args(PLyProcedure *proc, PLySavedArgs *savedargs); +static void PLy_function_drop_args(PLySavedArgs *savedargs); +static void PLy_global_args_push(PLyProcedure *proc); +static void PLy_global_args_pop(PLyProcedure *proc); +static void plpython_srf_cleanup_callback(void *arg); static void plpython_return_error_callback(void *arg); static PyObject *PLy_trigger_build_args(FunctionCallInfo fcinfo, PLyProcedure *proc, @@ -36,7 +49,7 @@ static HeapTuple PLy_modify_tuple(PLyProcedure *proc, PyObject *pltd, TriggerData *tdata, HeapTuple otup); static void plpython_trigger_error_callback(void *arg); -static PyObject *PLy_procedure_call(PLyProcedure *proc, char *kargs, PyObject *vargs); +static PyObject *PLy_procedure_call(PLyProcedure *proc, const char *kargs, PyObject *vargs); static void PLy_abort_open_subtransactions(int save_subxact_level); @@ -47,28 +60,65 @@ PLy_exec_function(FunctionCallInfo fcinfo, PLyProcedure *proc) Datum rv; PyObject *volatile plargs = NULL; PyObject *volatile plrv = NULL; + FuncCallContext *volatile funcctx = NULL; + PLySRFState *volatile srfstate = NULL; ErrorContextCallback plerrcontext; + /* + * If the function is called recursively, we must push outer-level + * arguments into the stack. This must be immediately before the PG_TRY + * to ensure that the corresponding pop happens. + */ + PLy_global_args_push(proc); + PG_TRY(); { - if (!proc->is_setof || proc->setof == NULL) + if (proc->is_setof) + { + /* First Call setup */ + if (SRF_IS_FIRSTCALL()) + { + funcctx = SRF_FIRSTCALL_INIT(); + srfstate = (PLySRFState *) + MemoryContextAllocZero(funcctx->multi_call_memory_ctx, + sizeof(PLySRFState)); + /* Immediately register cleanup callback */ + srfstate->callback.func = plpython_srf_cleanup_callback; + srfstate->callback.arg = (void *) srfstate; + MemoryContextRegisterResetCallback(funcctx->multi_call_memory_ctx, + &srfstate->callback); + funcctx->user_fctx = (void *) srfstate; + } + /* Every call setup */ + funcctx = SRF_PERCALL_SETUP(); + Assert(funcctx != NULL); + srfstate = (PLySRFState *) funcctx->user_fctx; + } + + if (srfstate == NULL || srfstate->iter == NULL) { /* - * Simple type returning function or first time for SETOF - * function: actually execute the function. + * Non-SETOF function or first time for SETOF function: build + * args, then actually execute the function. */ plargs = PLy_function_build_args(fcinfo, proc); plrv = PLy_procedure_call(proc, "args", plargs); - if (!proc->is_setof) - { - /* - * SETOF function parameters will be deleted when last row is - * returned - */ - PLy_function_delete_args(proc); - } Assert(plrv != NULL); } + else + { + /* + * Second or later call for a SETOF function: restore arguments in + * globals dict to what they were when we left off. We must do + * this in case multiple evaluations of the same SETOF function + * are interleaved. It's a bit annoying, since the iterator may + * not look at the arguments at all, but we have no way to know + * that. Fortunately this isn't terribly expensive. + */ + if (srfstate->savedargs) + PLy_function_restore_args(proc, srfstate->savedargs); + srfstate->savedargs = NULL; /* deleted by restore_args */ + } /* * If it returns a set, call the iterator to get the next return item. @@ -77,12 +127,11 @@ PLy_exec_function(FunctionCallInfo fcinfo, PLyProcedure *proc) */ if (proc->is_setof) { - bool has_error = false; - ReturnSetInfo *rsi = (ReturnSetInfo *) fcinfo->resultinfo; - - if (proc->setof == NULL) + if (srfstate->iter == NULL) { /* first time -- do checks and setup */ + ReturnSetInfo *rsi = (ReturnSetInfo *) fcinfo->resultinfo; + if (!rsi || !IsA(rsi, ReturnSetInfo) || (rsi->allowedModes & SFRM_ValuePerCall) == 0) { @@ -94,11 +143,12 @@ PLy_exec_function(FunctionCallInfo fcinfo, PLyProcedure *proc) rsi->returnMode = SFRM_ValuePerCall; /* Make iterator out of returned object */ - proc->setof = PyObject_GetIter(plrv); + srfstate->iter = PyObject_GetIter(plrv); + Py_DECREF(plrv); plrv = NULL; - if (proc->setof == NULL) + if (srfstate->iter == NULL) ereport(ERROR, (errcode(ERRCODE_DATATYPE_MISMATCH), errmsg("returned object cannot be iterated"), @@ -106,35 +156,30 @@ PLy_exec_function(FunctionCallInfo fcinfo, PLyProcedure *proc) } /* Fetch next from iterator */ - plrv = PyIter_Next(proc->setof); - if (plrv) - rsi->isDone = ExprMultipleResult; - else - { - rsi->isDone = ExprEndResult; - has_error = PyErr_Occurred() != NULL; - } - - if (rsi->isDone == ExprEndResult) + plrv = PyIter_Next(srfstate->iter); + if (plrv == NULL) { /* Iterator is exhausted or error happened */ - Py_DECREF(proc->setof); - proc->setof = NULL; + bool has_error = (PyErr_Occurred() != NULL); - Py_XDECREF(plargs); - Py_XDECREF(plrv); - - PLy_function_delete_args(proc); + Py_DECREF(srfstate->iter); + srfstate->iter = NULL; if (has_error) PLy_elog(ERROR, "error fetching next item from iterator"); - /* Disconnect from the SPI manager before returning */ - if (SPI_finish() != SPI_OK_FINISH) - elog(ERROR, "SPI_finish failed"); - - fcinfo->isnull = true; - return (Datum) NULL; + /* Pass a null through the data-returning steps below */ + Py_INCREF(Py_None); + plrv = Py_None; + } + else + { + /* + * This won't be last call, so save argument values. We do + * this again each time in case the iterator is changing those + * values. + */ + srfstate->savedargs = PLy_function_save_args(proc); } } @@ -170,7 +215,15 @@ PLy_exec_function(FunctionCallInfo fcinfo, PLyProcedure *proc) else if (plrv == Py_None) { fcinfo->isnull = true; - if (proc->result.is_rowtype < 1) + + /* + * In a SETOF function, the iteration-ending null isn't a real + * value; don't pass it through the input function, which might + * complain. + */ + if (srfstate && srfstate->iter == NULL) + rv = (Datum) 0; + else if (proc->result.is_rowtype < 1) rv = InputFunctionCall(&proc->result.out.d.typfunc, NULL, proc->result.out.d.typioparam, @@ -205,16 +258,28 @@ PLy_exec_function(FunctionCallInfo fcinfo, PLyProcedure *proc) } PG_CATCH(); { + /* Pop old arguments from the stack if they were pushed above */ + PLy_global_args_pop(proc); + Py_XDECREF(plargs); Py_XDECREF(plrv); /* - * If there was an error the iterator might have not been exhausted - * yet. Set it to NULL so the next invocation of the function will - * start the iteration again. + * If there was an error within a SRF, the iterator might not have + * been exhausted yet. Clear it so the next invocation of the + * function will start the iteration again. (This code is probably + * unnecessary now; plpython_srf_cleanup_callback should take care of + * cleanup. But it doesn't hurt anything to do it here.) */ - Py_XDECREF(proc->setof); - proc->setof = NULL; + if (srfstate) + { + Py_XDECREF(srfstate->iter); + srfstate->iter = NULL; + /* And drop any saved args; we won't need them */ + if (srfstate->savedargs) + PLy_function_drop_args(srfstate->savedargs); + srfstate->savedargs = NULL; + } PG_RE_THROW(); } @@ -222,9 +287,27 @@ PLy_exec_function(FunctionCallInfo fcinfo, PLyProcedure *proc) error_context_stack = plerrcontext.previous; + /* Pop old arguments from the stack if they were pushed above */ + PLy_global_args_pop(proc); + Py_XDECREF(plargs); Py_DECREF(plrv); + if (srfstate) + { + /* We're in a SRF, exit appropriately */ + if (srfstate->iter == NULL) + { + /* Iterator exhausted, so we're done */ + SRF_RETURN_DONE(funcctx); + } + else if (fcinfo->isnull) + SRF_RETURN_NEXT_NULL(funcctx); + else + SRF_RETURN_NEXT(funcctx, rv); + } + + /* Plain function, just return the Datum value (possibly null) */ return rv; } @@ -431,17 +514,195 @@ PLy_function_build_args(FunctionCallInfo fcinfo, PLyProcedure *proc) return args; } +/* + * Construct a PLySavedArgs struct representing the current values of the + * procedure's arguments in its globals dict. This can be used to restore + * those values when exiting a recursive call level or returning control to a + * set-returning function. + * + * This would not be necessary except for an ancient decision to make args + * available via the proc's globals :-( ... but we're stuck with that now. + */ +static PLySavedArgs * +PLy_function_save_args(PLyProcedure *proc) +{ + PLySavedArgs *result; + + /* saved args are always allocated in procedure's context */ + result = (PLySavedArgs *) + MemoryContextAllocZero(proc->mcxt, + offsetof(PLySavedArgs, namedargs) + + proc->nargs * sizeof(PyObject *)); + result->nargs = proc->nargs; + + /* Fetch the "args" list */ + result->args = PyDict_GetItemString(proc->globals, "args"); + Py_XINCREF(result->args); + + /* Fetch all the named arguments */ + if (proc->argnames) + { + int i; + + for (i = 0; i < result->nargs; i++) + { + if (proc->argnames[i]) + { + result->namedargs[i] = PyDict_GetItemString(proc->globals, + proc->argnames[i]); + Py_XINCREF(result->namedargs[i]); + } + } + } + + return result; +} + +/* + * Restore procedure's arguments from a PLySavedArgs struct, + * then free the struct. + */ static void -PLy_function_delete_args(PLyProcedure *proc) +PLy_function_restore_args(PLyProcedure *proc, PLySavedArgs *savedargs) +{ + /* Restore named arguments into their slots in the globals dict */ + if (proc->argnames) + { + int i; + + for (i = 0; i < savedargs->nargs; i++) + { + if (proc->argnames[i] && savedargs->namedargs[i]) + { + PyDict_SetItemString(proc->globals, proc->argnames[i], + savedargs->namedargs[i]); + Py_DECREF(savedargs->namedargs[i]); + } + } + } + + /* Restore the "args" object, too */ + if (savedargs->args) + { + PyDict_SetItemString(proc->globals, "args", savedargs->args); + Py_DECREF(savedargs->args); + } + + /* And free the PLySavedArgs struct */ + pfree(savedargs); +} + +/* + * Free a PLySavedArgs struct without restoring the values. + */ +static void +PLy_function_drop_args(PLySavedArgs *savedargs) { int i; - if (!proc->argnames) - return; + /* Drop references for named args */ + for (i = 0; i < savedargs->nargs; i++) + { + Py_XDECREF(savedargs->namedargs[i]); + } - for (i = 0; i < proc->nargs; i++) - if (proc->argnames[i]) - PyDict_DelItemString(proc->globals, proc->argnames[i]); + /* Drop ref to the "args" object, too */ + Py_XDECREF(savedargs->args); + + /* And free the PLySavedArgs struct */ + pfree(savedargs); +} + +/* + * Save away any existing arguments for the given procedure, so that we can + * install new values for a recursive call. This should be invoked before + * doing PLy_function_build_args(). + * + * NB: caller must ensure that PLy_global_args_pop gets invoked once, and + * only once, per successful completion of PLy_global_args_push. Otherwise + * we'll end up out-of-sync between the actual call stack and the contents + * of proc->argstack. + */ +static void +PLy_global_args_push(PLyProcedure *proc) +{ + /* We only need to push if we are already inside some active call */ + if (proc->calldepth > 0) + { + PLySavedArgs *node; + + /* Build a struct containing current argument values */ + node = PLy_function_save_args(proc); + + /* + * Push the saved argument values into the procedure's stack. Once we + * modify either proc->argstack or proc->calldepth, we had better + * return without the possibility of error. + */ + node->next = proc->argstack; + proc->argstack = node; + } + proc->calldepth++; +} + +/* + * Pop old arguments when exiting a recursive call. + * + * Note: the idea here is to adjust the proc's callstack state before doing + * anything that could possibly fail. In event of any error, we want the + * callstack to look like we've done the pop. Leaking a bit of memory is + * tolerable. + */ +static void +PLy_global_args_pop(PLyProcedure *proc) +{ + Assert(proc->calldepth > 0); + /* We only need to pop if we were already inside some active call */ + if (proc->calldepth > 1) + { + PLySavedArgs *ptr = proc->argstack; + + /* Pop the callstack */ + Assert(ptr != NULL); + proc->argstack = ptr->next; + proc->calldepth--; + + /* Restore argument values, then free ptr */ + PLy_function_restore_args(proc, ptr); + } + else + { + /* Exiting call depth 1 */ + Assert(proc->argstack == NULL); + proc->calldepth--; + + /* + * We used to delete the named arguments (but not "args") from the + * proc's globals dict when exiting the outermost call level for a + * function. This seems rather pointless though: nothing can see the + * dict until the function is called again, at which time we'll + * overwrite those dict entries. So don't bother with that. + */ + } +} + +/* + * Memory context deletion callback for cleaning up a PLySRFState. + * We need this in case execution of the SRF is terminated early, + * due to error or the caller simply not running it to completion. + */ +static void +plpython_srf_cleanup_callback(void *arg) +{ + PLySRFState *srfstate = (PLySRFState *) arg; + + /* Release refcount on the iter, if we still have one */ + Py_XDECREF(srfstate->iter); + srfstate->iter = NULL; + /* And drop any saved args; we won't need them */ + if (srfstate->savedargs) + PLy_function_drop_args(srfstate->savedargs); + srfstate->savedargs = NULL; } static void @@ -785,7 +1046,7 @@ plpython_trigger_error_callback(void *arg) /* execute Python code, propagate Python errors to the backend */ static PyObject * -PLy_procedure_call(PLyProcedure *proc, char *kargs, PyObject *vargs) +PLy_procedure_call(PLyProcedure *proc, const char *kargs, PyObject *vargs) { PyObject *rv; int volatile save_subxact_level = list_length(explicit_subtransactions); diff --git a/src/pl/plpython/plpy_procedure.c b/src/pl/plpython/plpy_procedure.c index a0d0792297..70b75f5d95 100644 --- a/src/pl/plpython/plpy_procedure.c +++ b/src/pl/plpython/plpy_procedure.c @@ -188,10 +188,11 @@ PLy_procedure_create(HeapTuple procTup, Oid fn_oid, bool is_trigger) proc->pyname = pstrdup(procName); proc->fn_xmin = HeapTupleHeaderGetRawXmin(procTup->t_data); proc->fn_tid = procTup->t_self; - /* Remember if function is STABLE/IMMUTABLE */ - proc->fn_readonly = - (procStruct->provolatile != PROVOLATILE_VOLATILE); + proc->fn_readonly = (procStruct->provolatile != PROVOLATILE_VOLATILE); + proc->is_setof = procStruct->proretset; PLy_typeinfo_init(&proc->result, proc->mcxt); + proc->src = NULL; + proc->argnames = NULL; for (i = 0; i < FUNC_MAX_ARGS; i++) PLy_typeinfo_init(&proc->args[i], proc->mcxt); proc->nargs = 0; @@ -200,12 +201,11 @@ PLy_procedure_create(HeapTuple procTup, Oid fn_oid, bool is_trigger) Anum_pg_proc_protrftypes, &isnull); proc->trftypes = isnull ? NIL : oid_array_to_list(protrftypes_datum); - proc->code = proc->statics = NULL; + proc->code = NULL; + proc->statics = NULL; proc->globals = NULL; - proc->is_setof = procStruct->proretset; - proc->setof = NULL; - proc->src = NULL; - proc->argnames = NULL; + proc->calldepth = 0; + proc->argstack = NULL; /* * get information required for output conversion of the return value, diff --git a/src/pl/plpython/plpy_procedure.h b/src/pl/plpython/plpy_procedure.h index 9fc8db0797..8ffa38e068 100644 --- a/src/pl/plpython/plpy_procedure.h +++ b/src/pl/plpython/plpy_procedure.h @@ -11,6 +11,15 @@ extern void init_procedure_caches(void); +/* saved arguments for outer recursion level or set-returning function */ +typedef struct PLySavedArgs +{ + struct PLySavedArgs *next; /* linked-list pointer */ + PyObject *args; /* "args" element of globals dict */ + int nargs; /* length of namedargs array */ + PyObject *namedargs[FLEXIBLE_ARRAY_MEMBER]; /* named args */ +} PLySavedArgs; + /* cached procedure data */ typedef struct PLyProcedure { @@ -21,10 +30,9 @@ typedef struct PLyProcedure TransactionId fn_xmin; ItemPointerData fn_tid; bool fn_readonly; + bool is_setof; /* true, if procedure returns result set */ PLyTypeInfo result; /* also used to store info for trigger tuple * type */ - bool is_setof; /* true, if procedure returns result set */ - PyObject *setof; /* contents of result set. */ char *src; /* textual procedure code, after mangling */ char **argnames; /* Argument names */ PLyTypeInfo args[FUNC_MAX_ARGS]; @@ -34,6 +42,8 @@ typedef struct PLyProcedure PyObject *code; /* compiled procedure code */ PyObject *statics; /* data saved across calls, local scope */ PyObject *globals; /* data saved across calls, global scope */ + long calldepth; /* depth of recursive calls of function */ + PLySavedArgs *argstack; /* stack of outer-level call arguments */ } PLyProcedure; /* the procedure cache key */ diff --git a/src/pl/plpython/sql/plpython_setof.sql b/src/pl/plpython/sql/plpython_setof.sql index fe034fba45..16c2eef0ad 100644 --- a/src/pl/plpython/sql/plpython_setof.sql +++ b/src/pl/plpython/sql/plpython_setof.sql @@ -63,6 +63,18 @@ SELECT test_setof_as_iterator(2, null); SELECT test_setof_spi_in_iterator(); +-- set-returning function that modifies its parameters +CREATE OR REPLACE FUNCTION ugly(x int, lim int) RETURNS SETOF int AS $$ +global x +while x <= lim: + yield x + x = x + 1 +$$ LANGUAGE plpythonu; + +SELECT ugly(1, 5); + +-- interleaved execution of such a function +SELECT ugly(1,3), ugly(7,8); -- returns set of named-composite-type tuples CREATE OR REPLACE FUNCTION get_user_records() diff --git a/src/pl/plpython/sql/plpython_spi.sql b/src/pl/plpython/sql/plpython_spi.sql index a882738e0b..87170609da 100644 --- a/src/pl/plpython/sql/plpython_spi.sql +++ b/src/pl/plpython/sql/plpython_spi.sql @@ -52,9 +52,6 @@ return None ' LANGUAGE plpythonu; - - - CREATE FUNCTION join_sequences(s sequences) RETURNS text AS 'if not s["multipart"]: @@ -68,10 +65,16 @@ return seq ' LANGUAGE plpythonu; +CREATE FUNCTION spi_recursive_sum(a int) RETURNS int + AS +'r = 0 +if a > 1: + r = plpy.execute("SELECT spi_recursive_sum(%d) as a" % (a-1))[0]["a"] +return a + r +' + LANGUAGE plpythonu; - - - +-- -- spi and nested calls -- select nested_call_one('pass this along'); @@ -79,15 +82,13 @@ select spi_prepared_plan_test_one('doe'); select spi_prepared_plan_test_one('smith'); select spi_prepared_plan_test_nested('smith'); - - - SELECT join_sequences(sequences) FROM sequences; SELECT join_sequences(sequences) FROM sequences WHERE join_sequences(sequences) ~* '^A'; SELECT join_sequences(sequences) FROM sequences WHERE join_sequences(sequences) ~* '^B'; +SELECT spi_recursive_sum(10); -- -- plan and result objects