After upgrading from pystan 2.19.1.1 to pystan 3.9.1, I get the following error when running a UDF in PySpark. Running the same function without a UDF works fine. This is being executed on Azure Databricks.
import stan
import nest_asyncio
from pyspark.sql import functions as F
from pyspark.sql import types as T
nest_asyncio.apply()
df_input = spark.createDataFrame(
    [
        (8, [28, 8, -3, 7, -1, 1, 18, 12], [15, 10, 16, 11, 9, 11, 10, 18], 'West_Sobeys'),
        (7, [28, 8, -3, 7, -1, 1, 18, 12], [15, 10, 16, 11, 9, 11, 10, 21], 'Ontario_Sobeys')
    ],
    ('J', 'y', 'sigma', 'exec_id')
)
pdf_input = spark.createDataFrame(
    [
        (8, [28, 8, -3, 7, -1, 1, 18, 12], [15, 10, 16, 11, 9, 11, 10, 18]),
    ],
    ('J', 'y', 'sigma')
).toPandas()
schools_code = """
data {
  int<lower=0> J;                // number of schools
  array[J] real y;               // estimated treatment effects
  array[J] real<lower=0> sigma;  // standard error of effect estimates
}
parameters {
  real mu;             // population treatment effect
  real<lower=0> tau;   // standard deviation in treatment effects
  vector[J] eta;       // unscaled deviation from mu by school
}
transformed parameters {
  vector[J] theta = mu + tau * eta;  // school treatment effects
}
model {
  target += normal_lpdf(eta | 0, 1);        // prior log-density
  target += normal_lpdf(y | theta, sigma);  // log-likelihood
}
"""
def call_stan_build(pdf):
    posterior = stan.build(schools_code, data=pdf.to_dict(orient='records')[0])
    fit = posterior.sample(num_chains=4, num_samples=1000)
    eta = fit["eta"]     # array with shape (8, 4000)
    df = fit.to_frame()  # pandas DataFrame, requires pandas
    return df[['mu']]
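One side note (probably unrelated to the compile error, but worth guarding against): inside applyInPandas the grouping column exec_id is included in pdf, so to_dict(orient='records')[0] carries an extra 'exec_id' key that the Stan data block does not declare, and I'm not sure pystan 3 tolerates undeclared fields. A hedged variant that strips it first:

def call_stan_build(pdf):
    # drop columns the data block does not declare; errors='ignore' keeps
    # the function usable on pdf_input, which has no exec_id column
    data = pdf.drop(columns=['exec_id'], errors='ignore').to_dict(orient='records')[0]
    posterior = stan.build(schools_code, data=data)
    fit = posterior.sample(num_chains=4, num_samples=1000)
    return fit.to_frame()[['mu']]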
This works:

pdf_out = call_stan_build(pdf_input)
print(pdf_out)  # pdf_out is a pandas DataFrame, which has no .show() method
This doesn't work:
df_out = df_input.groupby("exec_id").applyInPandas(
call_stan_build, schema="mu float")
df_out.show()
Here is the error:
PythonException Traceback (most recent call last)
File <command-1383263685039439>, line 4
1 df_out = df_input.groupby("exec_id").applyInPandas(
2 call_stan_build, schema="mu float")
----> 4 df_out.show()
File /databricks/spark/python/pyspark/instrumentation_utils.py:47, in _wrap_function.<locals>.wrapper(*args, **kwargs)
45 start = time.perf_counter()
46 try:
---> 47 res = func(*args, **kwargs)
48 logger.log_success(
49 module_name, class_name, function_name, time.perf_counter() - start, signature
50 )
51 return res
File /databricks/spark/python/pyspark/sql/dataframe.py:1061, in DataFrame.show(self, n, truncate, vertical)
972 def show(self, n: int = 20, truncate: Union[bool, int] = True, vertical: bool = False) -> None:
973 """
974 Prints the first ``n`` rows of the DataFrame to the console.
975
(...)
1059 name | This is a super l...
1060 """
-> 1061 print(self._show_string(n, truncate, vertical))
File /databricks/spark/python/pyspark/sql/dataframe.py:1079, in DataFrame._show_string(self, n, truncate, vertical)
1073 raise PySparkTypeError(
1074 error_class="NOT_BOOL",
1075 message_parameters={"arg_name": "vertical", "arg_type": type(vertical).__name__},
1076 )
1078 if isinstance(truncate, bool) and truncate:
-> 1079 return self._jdf.showString(n, 20, vertical)
1080 else:
1081 try:
File /databricks/spark/python/lib/py4j-0.10.9.7-src.zip/py4j/java_gateway.py:1355, in JavaMember.__call__(self, *args)
1349 command = proto.CALL_COMMAND_NAME +\
1350 self.command_header +\
1351 args_command +\
1352 proto.END_COMMAND_PART
1354 answer = self.gateway_client.send_command(command)
-> 1355 return_value = get_return_value(
1356 answer, self.gateway_client, self.target_id, self.name)
1358 for temp_arg in temp_args:
1359 if hasattr(temp_arg, "_detach"):
File /databricks/spark/python/pyspark/errors/exceptions/captured.py:230, in capture_sql_exception.<locals>.deco(*a, **kw)
226 converted = convert_exception(e.java_exception)
227 if not isinstance(converted, UnknownException):
228 # Hide where the exception came from that shows a non-Pythonic
229 # JVM exception message.
--> 230 raise converted from None
231 else:
232 raise
PythonException:
An exception was thrown from the Python worker. Please see the stack trace below.
Traceback (most recent call last):
File "/root/.ipykernel/1581/command-1383263685039422-3633857622", line 46, in call_stan_build
File "/local_disk0/.ephemeral_nfs/envs/pythonEnv-ea53a83d-2996-4ac9-b948-372eba39f941/lib/python3.10/site-packages/stan/model.py", line 520, in build
except KeyboardInterrupt:
File "/usr/lib/python3.10/asyncio/runners.py", line 52, in run
loop.close()
File "/usr/lib/python3.10/asyncio/base_events.py", line 649, in run_until_complete
return future.result()
File "/local_disk0/.ephemeral_nfs/envs/pythonEnv-ea53a83d-2996-4ac9-b948-372eba39f941/lib/python3.10/site-packages/stan/model.py", line 488, in go
raise RuntimeError(resp.json()["message"])
RuntimeError: Exception while building model extension module: `CompileError(DistutilsExecError("command '/usr/bin/x86_64-linux-gnu-gcc' failed with exit code 1"))`, traceback: `[' File "/local_disk0/.ephemeral_nfs/envs/pythonEnv-ea53a83d-2996-4ac9-b948-372eba39f941/lib/python3.10/site-packages/httpstan/views.py", line 114, in handle_create_model\n compiler_output = await httpstan.models.build_services_extension_module(program_code)\n', ' File "/local_disk0/.ephemeral_nfs/envs/pythonEnv-ea53a83d-2996-4ac9-b948-372eba39f941/lib/python3.10/site-packages/httpstan/models.py", line 172, in build_services_extension_module\n compiler_output = await asyncio.get_running_loop().run_in_executor(\n', ' File "/usr/lib/python3.10/concurrent/futures/thread.py", line 58, in run\n result = self.fn(*self.args, **self.kwargs)\n', ' File "/local_disk0/.ephemeral_nfs/envs/pythonEnv-ea53a83d-2996-4ac9-b948-372eba39f941/lib/python3.10/site-packages/httpstan/build_ext.py", line 86, in run_build_ext\n build_extension.run()\n', ' File "/databricks/python/lib/python3.10/site-packages/setuptools/command/build_ext.py", line 84, in run\n _build_ext.run(self)\n', ' File "/databricks/python/lib/python3.10/site-packages/Cython/Distutils/old_build_ext.py", line 186, in run\n _build_ext.build_ext.run(self)\n', ' File "/databricks/python/lib/python3.10/site-packages/setuptools/_distutils/command/build_ext.py", line 346, in run\n self.build_extensions()\n', ' File "/databricks/python/lib/python3.10/site-packages/Cython/Distutils/old_build_ext.py", line 195, in build_extensions\n _build_ext.build_ext.build_extensions(self)\n', ' File "/databricks/python/lib/python3.10/site-packages/setuptools/_distutils/command/build_ext.py", line 468, in build_extensions\n self._build_extensions_serial()\n', ' File "/databricks/python/lib/python3.10/site-packages/setuptools/_distutils/command/build_ext.py", line 494, in _build_extensions_serial\n self.build_extension(ext)\n', ' File "/databricks/python/lib/python3.10/site-packages/setuptools/command/build_ext.py", line 246, in build_extension\n _build_ext.build_extension(self, ext)\n', ' File "/databricks/python/lib/python3.10/site-packages/setuptools/_distutils/command/build_ext.py", line 549, in build_extension\n objects = self.compiler.compile(\n', ' File "/databricks/python/lib/python3.10/site-packages/setuptools/_distutils/ccompiler.py", line 599, in compile\n self._compile(obj, src, ext, cc_args, extra_postargs, pp_opts)\n', ' File "/databricks/python/lib/python3.10/site-packages/setuptools/_distutils/unixccompiler.py", line 188, in _compile\n raise CompileError(msg)\n']`
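The RuntimeError shows httpstan invoking /usr/bin/x86_64-linux-gnu-gcc on the executor and getting exit code 1, but the compiler's actual stderr is swallowed. To narrow it down, a diagnostic sketch I would try next (the check_gcc helper is purely illustrative, and it assumes the executors expose the same gcc path): push a trivial compile through the same applyInPandas path and surface the real compiler output:

import os
import subprocess
import tempfile
import pandas as pd

def check_gcc(pdf):
    # compile a trivial C file on the executor and capture gcc's real stderr
    with tempfile.NamedTemporaryFile(suffix='.c', mode='w', delete=False) as f:
        f.write('int main(void) { return 0; }\n')
        src = f.name
    proc = subprocess.run(
        ['/usr/bin/x86_64-linux-gnu-gcc', src, '-o', src + '.out'],
        capture_output=True, text=True,
    )
    os.remove(src)
    return pd.DataFrame({'rc': [proc.returncode], 'stderr': [proc.stderr[:1000]]})

df_input.groupby('exec_id').applyInPandas(
    check_gcc, schema='rc int, stderr string').show(truncate=False)

If this succeeds, the toolchain itself is fine and the failure is specific to building the httpstan extension module (missing Python headers or memory pressure on the workers would be my next suspects); if it fails, the stderr column should show why gcc exits with code 1.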