Pystan3 in UDF gives error

Upgrading from pystan 2.19.1.1 to pystan 3.9.1 gives the following error when using a UDF in PySpark. Running the same function without a UDF works fine. This is being executed in Azure Databricks.

import stan
import nest_asyncio
from pyspark.sql import functions as F
from pyspark.sql import types as T

nest_asyncio.apply()

# Spark DataFrame with one row per group ('exec_id'): J schools, their
# estimated effects y, and standard errors sigma (eight-schools data).
df_input = spark.createDataFrame(
  [
    (8, [28,  8, -3,  7, -1,  1, 18, 12], [15, 10, 16, 11,  9, 11, 10, 18], 'West_Sobeys'),
    (7, [28,  8, -3,  7, -1,  1, 18, 12], [15, 10, 16, 11,  9, 11, 10, 21], 'Ontario_Sobeys')
  ],
  ('J', 'y', 'sigma', 'exec_id')
)

# Same shape of data as a single-row pandas DataFrame, used to call the
# function directly on the driver (the non-UDF path that works).
pdf_input = spark.createDataFrame(
  [
    (8, [28,  8, -3,  7, -1,  1, 18, 12], [15, 10, 16, 11,  9, 11, 10, 18]),
  ],
  ('J', 'y', 'sigma')
).toPandas()

# Stan program for the classic eight-schools hierarchical model
# (pystan 3 / Stan 2.26+ `array[...]` syntax).
schools_code = """
data {
  int<lower=0> J;         // number of schools
  array[J] real y;              // estimated treatment effects
  array[J] real<lower=0> sigma; // standard error of effect estimates
}
parameters {
  real mu;                // population treatment effect
  real<lower=0> tau;      // standard deviation in treatment effects
  vector[J] eta;          // unscaled deviation from mu by school
}
transformed parameters {
  vector[J] theta = mu + tau * eta;        // school treatment effects
}
model {
  target += normal_lpdf(eta | 0, 1);       // prior log-density
  target += normal_lpdf(y | theta, sigma); // log-likelihood
}
"""


def call_stan_build(pdf):
  """Compile and sample the eight-schools Stan model for one group.

  Parameters
  ----------
  pdf : pandas.DataFrame
      One-row frame with columns 'J', 'y', 'sigma' (extra grouping
      columns such as 'exec_id' are passed through to Stan untouched —
      TODO confirm Stan ignores unknown data keys in your version).

  Returns
  -------
  pandas.DataFrame
      Single-column frame of posterior draws for ``mu``.
  """
  # stan.build compiles the model where this runs (the driver when called
  # directly, a Python worker under applyInPandas); the first row of the
  # frame supplies the data dict.
  posterior = stan.build(schools_code, data=pdf.to_dict(orient='records')[0])

  # 4 chains x 1000 draws -> 4000 posterior samples per parameter.
  fit = posterior.sample(num_chains=4, num_samples=1000)
  df = fit.to_frame()  # pandas DataFrame, one column per draw variable
  return df[['mu']]

This works:

# Direct (non-UDF) call on the driver.
pdf_out = call_stan_build(pdf_input)

# `pdf_out` is a pandas DataFrame; pandas has no .show() (that is a
# Spark DataFrame method), so print it instead.
print(pdf_out)

This doesn’t work:

# Failing path: the same function run per group through applyInPandas.
# The Stan model is then compiled inside the Spark Python worker, where
# the gcc invocation fails (see traceback below).
df_out = df_input.groupby("exec_id").applyInPandas(
    call_stan_build, schema="mu float")

# show() triggers execution, which is where the worker-side error surfaces.
df_out.show()

Here is the error:

PythonException                           Traceback (most recent call last)
File <command-1383263685039439>, line 4
      1 df_out = df_input.groupby("exec_id").applyInPandas(
      2     call_stan_build, schema="mu float")
----> 4 df_out.show()

File /databricks/spark/python/pyspark/instrumentation_utils.py:47, in _wrap_function.<locals>.wrapper(*args, **kwargs)
     45 start = time.perf_counter()
     46 try:
---> 47     res = func(*args, **kwargs)
     48     logger.log_success(
     49         module_name, class_name, function_name, time.perf_counter() - start, signature
     50     )
     51     return res

File /databricks/spark/python/pyspark/sql/dataframe.py:1061, in DataFrame.show(self, n, truncate, vertical)
    972 def show(self, n: int = 20, truncate: Union[bool, int] = True, vertical: bool = False) -> None:
    973     """
    974     Prints the first ``n`` rows of the DataFrame to the console.
    975 
   (...)
   1059     name | This is a super l...
   1060     """
-> 1061     print(self._show_string(n, truncate, vertical))

File /databricks/spark/python/pyspark/sql/dataframe.py:1079, in DataFrame._show_string(self, n, truncate, vertical)
   1073     raise PySparkTypeError(
   1074         error_class="NOT_BOOL",
   1075         message_parameters={"arg_name": "vertical", "arg_type": type(vertical).__name__},
   1076     )
   1078 if isinstance(truncate, bool) and truncate:
-> 1079     return self._jdf.showString(n, 20, vertical)
   1080 else:
   1081     try:

File /databricks/spark/python/lib/py4j-0.10.9.7-src.zip/py4j/java_gateway.py:1355, in JavaMember.__call__(self, *args)
   1349 command = proto.CALL_COMMAND_NAME +\
   1350     self.command_header +\
   1351     args_command +\
   1352     proto.END_COMMAND_PART
   1354 answer = self.gateway_client.send_command(command)
-> 1355 return_value = get_return_value(
   1356     answer, self.gateway_client, self.target_id, self.name)
   1358 for temp_arg in temp_args:
   1359     if hasattr(temp_arg, "_detach"):

File /databricks/spark/python/pyspark/errors/exceptions/captured.py:230, in capture_sql_exception.<locals>.deco(*a, **kw)
    226 converted = convert_exception(e.java_exception)
    227 if not isinstance(converted, UnknownException):
    228     # Hide where the exception came from that shows a non-Pythonic
    229     # JVM exception message.
--> 230     raise converted from None
    231 else:
    232     raise

PythonException: 
  An exception was thrown from the Python worker. Please see the stack trace below.
Traceback (most recent call last):
  File "/root/.ipykernel/1581/command-1383263685039422-3633857622", line 46, in call_stan_build
  File "/local_disk0/.ephemeral_nfs/envs/pythonEnv-ea53a83d-2996-4ac9-b948-372eba39f941/lib/python3.10/site-packages/stan/model.py", line 520, in build
    except KeyboardInterrupt:
  File "/usr/lib/python3.10/asyncio/runners.py", line 52, in run
    loop.close()
  File "/usr/lib/python3.10/asyncio/base_events.py", line 649, in run_until_complete
    return future.result()
  File "/local_disk0/.ephemeral_nfs/envs/pythonEnv-ea53a83d-2996-4ac9-b948-372eba39f941/lib/python3.10/site-packages/stan/model.py", line 488, in go
    raise RuntimeError(resp.json()["message"])
RuntimeError: Exception while building model extension module: `CompileError(DistutilsExecError("command '/usr/bin/x86_64-linux-gnu-gcc' failed with exit code 1"))`, traceback: `['  File "/local_disk0/.ephemeral_nfs/envs/pythonEnv-ea53a83d-2996-4ac9-b948-372eba39f941/lib/python3.10/site-packages/httpstan/views.py", line 114, in handle_create_model\n    compiler_output = await httpstan.models.build_services_extension_module(program_code)\n', '  File "/local_disk0/.ephemeral_nfs/envs/pythonEnv-ea53a83d-2996-4ac9-b948-372eba39f941/lib/python3.10/site-packages/httpstan/models.py", line 172, in build_services_extension_module\n    compiler_output = await asyncio.get_running_loop().run_in_executor(\n', '  File "/usr/lib/python3.10/concurrent/futures/thread.py", line 58, in run\n    result = self.fn(*self.args, **self.kwargs)\n', '  File "/local_disk0/.ephemeral_nfs/envs/pythonEnv-ea53a83d-2996-4ac9-b948-372eba39f941/lib/python3.10/site-packages/httpstan/build_ext.py", line 86, in run_build_ext\n    build_extension.run()\n', '  File "/databricks/python/lib/python3.10/site-packages/setuptools/command/build_ext.py", line 84, in run\n    _build_ext.run(self)\n', '  File "/databricks/python/lib/python3.10/site-packages/Cython/Distutils/old_build_ext.py", line 186, in run\n    _build_ext.build_ext.run(self)\n', '  File "/databricks/python/lib/python3.10/site-packages/setuptools/_distutils/command/build_ext.py", line 346, in run\n    self.build_extensions()\n', '  File "/databricks/python/lib/python3.10/site-packages/Cython/Distutils/old_build_ext.py", line 195, in build_extensions\n    _build_ext.build_ext.build_extensions(self)\n', '  File "/databricks/python/lib/python3.10/site-packages/setuptools/_distutils/command/build_ext.py", line 468, in build_extensions\n    self._build_extensions_serial()\n', '  File "/databricks/python/lib/python3.10/site-packages/setuptools/_distutils/command/build_ext.py", line 494, in _build_extensions_serial\n    self.build_extension(ext)\n', '  
File "/databricks/python/lib/python3.10/site-packages/setuptools/command/build_ext.py", line 246, in build_extension\n    _build_ext.build_extension(self, ext)\n', '  File "/databricks/python/lib/python3.10/site-packages/setuptools/_distutils/command/build_ext.py", line 549, in build_extension\n    objects = self.compiler.compile(\n', '  File "/databricks/python/lib/python3.10/site-packages/setuptools/_distutils/ccompiler.py", line 599, in compile\n    self._compile(obj, src, ext, cc_args, extra_postargs, pp_opts)\n', '  File "/databricks/python/lib/python3.10/site-packages/setuptools/_distutils/unixccompiler.py", line 188, in _compile\n    raise CompileError(msg)\n']`

Update:
This issue is resolved on a newer Databricks Runtime version.

Thanks so much for updating with the solution!