Hi all,
I would like to add a time limit to the sampling, as most of the time these samples are not well-fitted in my case. Although it's a popular request here, I didn't find a suitable solution.
After looking at the source code I am thinking about altering the _create_fit()
function as follows:
...
# poll to get progress for each chain until all chains finished
current_iterations = {}
started_at = time.monotonic()
while not all(operation["done"] for operation in operations):
if time.monotonic() - started_at >= timeout:
raise TimeoutError
...
Is that sufficient, or do we need to delete cached models / fits / work done so far?
I tried the above in a for-loop,
and it looks like the fits are getting clogged: each time the sampling stops at 25%, and after some time the timeout is triggered.
Any idea what’s causing this?
Full _create_fit()
def _create_fit(self, *, function, num_chains, timeout=None, **kwargs) -> stan.fit.Fit:
    """Make a request to httpstan's ``create_fit`` endpoint and process results.

    Users should not use this function.

    Parameters in ``kwargs`` will be passed to the (Python wrapper of)
    `function`. Parameter names are identical to those used in CmdStan.
    See the CmdStan documentation for parameter descriptions and default
    values.

    Arguments:
        function: Name of the httpstan service function to invoke.
        num_chains: Number of chains to sample.
        timeout: Maximum number of seconds to wait for all chains to
            finish. ``None`` (the default) waits indefinitely, matching
            the behavior before this parameter existed.

    Returns:
        Fit: instance of Fit allowing access to draws.

    Raises:
        ValueError: if initial values are missing for a chain, or the
            server rejects a payload (HTTP 422).
        RuntimeError: on other server errors or failed initialization.
        TimeoutError: if sampling does not finish within ``timeout``
            seconds. NOTE(review): nothing here cancels the server-side
            operations when this is raised — the chains presumably keep
            sampling (and consuming CPU) inside the httpstan server, and
            any partial fits are not deleted. Verify against httpstan's
            API; this is the likely cause of subsequent runs appearing
            "clogged".
    """
    assert "chain" not in kwargs, "`chain` id is set automatically."
    assert "data" not in kwargs, "`data` is set in `build`."
    assert "random_seed" not in kwargs, "`random_seed` is set in `build`."

    # copy kwargs and verify everything is JSON-encodable
    kwargs = json.loads(DataJSONEncoder().encode(kwargs))

    # FIXME: special handling here for `init`, consistent with PyStan 2 but needs docs
    init: List[Data] = kwargs.pop("init", [dict() for _ in range(num_chains)])
    if len(init) != num_chains:
        raise ValueError("Initial values must be provided for each chain.")

    # Build one request payload per chain; chain ids are 1-based.
    payloads = []
    for chain in range(1, num_chains + 1):
        payload = kwargs.copy()
        payload["function"] = function
        payload["chain"] = chain  # type: ignore
        payload["data"] = self.data  # type: ignore
        payload["init"] = init.pop(0)
        if self.random_seed is not None:
            payload["random_seed"] = self.random_seed  # type: ignore

        # fit needs to know num_samples, num_warmup, num_thin, save_warmup
        # progress reporting needs to know some of these
        num_warmup = payload.get("num_warmup", arguments.lookup_default(arguments.Method["SAMPLE"], "num_warmup"))
        num_samples = payload.get(
            "num_samples",
            arguments.lookup_default(arguments.Method["SAMPLE"], "num_samples"),
        )
        num_thin = payload.get("num_thin", arguments.lookup_default(arguments.Method["SAMPLE"], "num_thin"))
        save_warmup = payload.get(
            "save_warmup",
            arguments.lookup_default(arguments.Method["SAMPLE"], "save_warmup"),
        )
        payloads.append(payload)

    async def go():
        # Console section used for in-place progress updates.
        io = ConsoleIO()
        sampling_output = io.section().error_output
        percent_complete = 0
        sampling_output.write_line(f"<comment>Sampling:</comment> {percent_complete:3.0f}%")

        current_and_max_iterations_re = re.compile(r"Iteration:\s+(\d+)\s+/\s+(\d+)")
        async with stan.common.HttpstanClient() as client:
            # Kick off one fit per chain; each POST returns an Operation.
            operations = []
            for payload in payloads:
                resp = await client.post(f"/{self.model_name}/fits", json=payload)
                if resp.status == 422:
                    raise ValueError(str(resp.json()))
                elif resp.status != 201:
                    raise RuntimeError(resp.json()["message"])
                assert resp.status == 201
                operations.append(resp.json())

            # poll to get progress for each chain until all chains finished
            current_iterations = {}
            # time.monotonic is immune to system clock adjustments.
            started_at = time.monotonic()
            while not all(operation["done"] for operation in operations):
                # Enforce the caller-supplied wall-clock limit, if any.
                # Guarding on `timeout is not None` preserves the original
                # wait-forever behavior for callers that pass no timeout.
                if timeout is not None and time.monotonic() - started_at >= timeout:
                    raise TimeoutError(
                        f"Sampling did not finish within {timeout} seconds "
                        f"({num_chains} chain(s) still running)."
                    )
                for operation in operations:
                    if operation["done"]:
                        continue
                    resp = await client.get(f"/{operation['name']}")
                    assert resp.status != 404
                    operation.update(resp.json())
                    progress_message = operation["metadata"].get("progress")
                    if not progress_message:
                        continue
                    iteration, iteration_max = map(
                        int, current_and_max_iterations_re.findall(progress_message).pop(0)
                    )
                    # Skip redrawing the progress line if nothing changed.
                    if current_iterations.get(operation["name"]) == iteration:
                        continue
                    current_iterations[operation["name"]] = iteration
                    iterations_count = sum(current_iterations.values())
                    total_iterations = iteration_max * num_chains
                    percent_complete = 100 * iterations_count / total_iterations
                    sampling_output.clear() if io.supports_ansi() else sampling_output.write("\n")
                    sampling_output.write_line(
                        f"<comment>Sampling:</comment> {round(percent_complete):3.0f}% ({iterations_count}/{total_iterations})"
                    )
                await asyncio.sleep(0.05)

            # If no progress messages were ever seen, the fit came from
            # httpstan's cache rather than fresh sampling.
            fit_in_cache = len(current_iterations) < num_chains

            stan_outputs = []
            for operation in operations:
                fit_name = operation["result"].get("name")
                if fit_name is None:  # operation["result"] is an error
                    assert not str(operation["result"]["code"]).startswith("2"), operation
                    message = operation["result"]["message"]
                    if """ValueError('Initialization failed.')""" in message:
                        sampling_output.clear()
                        sampling_output.write_line("<info>Sampling:</info> <error>Initialization failed.</error>")
                        raise RuntimeError("Initialization failed.")
                    raise RuntimeError(message)

                resp = await client.get(f"/{fit_name}")
                if resp.status != 200:
                    raise RuntimeError((resp.json())["message"])
                stan_outputs.append(resp.content)

                # clean up after ourselves when fit is uncacheable (no random seed)
                if self.random_seed is None:
                    resp = await client.delete(f"/{fit_name}")
                    if resp.status not in {200, 202, 204}:
                        raise RuntimeError((resp.json())["message"])

            sampling_output.clear() if io.supports_ansi() else sampling_output.write("\n")
            sampling_output.write_line(
                "<info>Sampling:</info> 100%, done."
                if fit_in_cache
                else f"<info>Sampling:</info> {percent_complete:3.0f}% ({iterations_count}/{total_iterations}), done."
            )
            if not io.supports_ansi():
                sampling_output.write("\n")

        stan_outputs = tuple(stan_outputs)  # Fit constructor expects a tuple.

        def is_nonempty_logger_message(msg: simdjson.Object):
            return msg["topic"] == "logger" and msg["values"][0] != "info:"  # type: ignore

        def is_iteration_or_elapsed_time_logger_message(msg: simdjson.Object):
            # Assumes `msg` is a message with topic `logger`.
            text = msg["values"][0]  # type: ignore
            text = cast(str, text)
            return (
                text.startswith("info:Iteration:")
                or text.startswith("info: Elapsed Time:")
                # this detects lines following "Elapsed Time:", part of a multi-line Stan message
                or text.startswith("info:" + " " * 15)
            )

        parser = simdjson.Parser()
        nonstandard_logger_messages = []
        for stan_output in stan_outputs:
            for line in stan_output.splitlines():
                # Do not attempt to parse non-logger messages. Draws could contain nan or inf values.
                # simdjson cannot parse lines containing such values.
                if b'"logger"' not in line:
                    continue
                msg = parser.parse(line)
                if is_nonempty_logger_message(msg) and not is_iteration_or_elapsed_time_logger_message(msg):
                    nonstandard_logger_messages.append(msg.as_dict())
        del msg
        del parser  # simdjson.Parser is no longer used at this point.

        if nonstandard_logger_messages:
            io.error_line("<comment>Messages received during sampling:</comment>")
            for msg in nonstandard_logger_messages:
                text = msg["values"][0].replace("info:", "  ").replace("error:", "  ")
                if text.strip():
                    io.error_line(f"{text}")

        fit = stan.fit.Fit(
            stan_outputs,
            num_chains,
            self.param_names,
            self.constrained_param_names,
            self.dims,
            num_warmup,
            num_samples,
            num_thin,
            save_warmup,
        )

        # Give installed plugins a chance to post-process the fit.
        for entry_point in stan.plugins.get_plugins():
            Plugin = entry_point.load()
            fit = Plugin().on_post_sample(fit)
        return fit

    try:
        return asyncio.run(go())
    except KeyboardInterrupt:
        return  # type: ignore