Add time limit to sampling

Hi all,

I would like to add a time limit to sampling, because in my case runs that take this long are usually poorly fitted anyway. Although it's a popular request here, I didn't find a suitable solution.

After looking at the source code, I am thinking about altering the `_create_fit()` function as follows:

...
# poll to get progress for each chain until all chains finished
current_iterations = {}
# Reference point for the time limit: measured from when polling starts,
# i.e. after the per-chain fit requests have already been posted.
started_at = time.monotonic()

while not all(operation["done"] for operation in operations):
    # Abort the client-side wait once `timeout` seconds have elapsed.
    # NOTE(review): this only stops *waiting*; nothing here cancels or deletes
    # the operations that were already posted, so they may keep running in
    # httpstan and delay later fits (e.g. when called in a loop) — confirm.
    if time.monotonic() - started_at >= timeout:
        raise TimeoutError
...

Is that sufficient, or do we also need to delete cached models, fits, or any work done so far?
I tried the above in a for-loop, and it looks like the fits are getting clogged: each time, sampling stops at 25%, and after a while the timeout is triggered.
Any idea what’s causing this?

Full `_create_fit()`:

def _create_fit(self, *, function, num_chains, timeout=None, **kwargs) -> stan.fit.Fit:
        """Make a request to httpstan's ``create_fit`` endpoint and process results.

        Users should not use this function.

        Parameters in ``kwargs`` will be passed to the (Python wrapper of)
        `function`. Parameter names are identical to those used in CmdStan.
        See the CmdStan documentation for parameter descriptions and default
        values.

        Args:
            function: service function name, sent as ``payload["function"]``
                in every per-chain request.
            num_chains: number of chains; one fit request is posted per chain.
            timeout: maximum number of seconds to wait while polling for the
                chains to finish, measured from the moment polling starts
                (after all fit requests have been posted). ``None`` (the
                default) waits indefinitely.

        Returns:
            Fit: instance of Fit allowing access to draws.

        Raises:
            ValueError: if ``init`` does not provide one entry per chain, or
                httpstan rejects a request with HTTP 422.
            RuntimeError: on any other httpstan error, including failed
                initialization.
            TimeoutError: if ``timeout`` elapses before all chains finish.

        """
        assert "chain" not in kwargs, "`chain` id is set automatically."
        assert "data" not in kwargs, "`data` is set in `build`."
        assert "random_seed" not in kwargs, "`random_seed` is set in `build`."

        # copy kwargs and verify everything is JSON-encodable
        kwargs = json.loads(DataJSONEncoder().encode(kwargs))

        # FIXME: special handling here for `init`, consistent with PyStan 2 but needs docs
        init: List[Data] = kwargs.pop("init", [dict() for _ in range(num_chains)])
        if len(init) != num_chains:
            raise ValueError("Initial values must be provided for each chain.")

        # fit needs to know num_samples, num_warmup, num_thin, save_warmup;
        # progress reporting needs to know some of these. Every per-chain
        # payload is a copy of ``kwargs`` (the keys added per chain never
        # collide with these), so read them once here instead of per chain.
        sample_method = arguments.Method["SAMPLE"]
        num_warmup = kwargs.get("num_warmup", arguments.lookup_default(sample_method, "num_warmup"))
        num_samples = kwargs.get("num_samples", arguments.lookup_default(sample_method, "num_samples"))
        num_thin = kwargs.get("num_thin", arguments.lookup_default(sample_method, "num_thin"))
        save_warmup = kwargs.get("save_warmup", arguments.lookup_default(sample_method, "save_warmup"))

        payloads = []
        for chain, chain_init in enumerate(init, start=1):
            payload = kwargs.copy()
            payload["function"] = function
            payload["chain"] = chain  # type: ignore
            payload["data"] = self.data  # type: ignore
            payload["init"] = chain_init
            if self.random_seed is not None:
                payload["random_seed"] = self.random_seed  # type: ignore
            payloads.append(payload)

        async def go():
            io = ConsoleIO()
            sampling_output = io.section().error_output
            percent_complete = 0
            # Initialised up front so the final status line is well-defined even
            # if no progress message is ever parsed (e.g. zero chains).
            iterations_count = 0
            total_iterations = 0
            sampling_output.write_line(f"<comment>Sampling:</comment> {percent_complete:3.0f}%")

            def reset_progress_line():
                # Overwrite the progress line on ANSI terminals; otherwise start a new line.
                if io.supports_ansi():
                    sampling_output.clear()
                else:
                    sampling_output.write("\n")

            current_and_max_iterations_re = re.compile(r"Iteration:\s+(\d+)\s+/\s+(\d+)")
            async with stan.common.HttpstanClient() as client:
                operations = []
                for payload in payloads:
                    resp = await client.post(f"/{self.model_name}/fits", json=payload)
                    if resp.status == 422:
                        raise ValueError(str(resp.json()))
                    elif resp.status != 201:
                        raise RuntimeError(resp.json()["message"])
                    assert resp.status == 201
                    operations.append(resp.json())

                # poll to get progress for each chain until all chains finished
                current_iterations = {}
                deadline = None if timeout is None else time.monotonic() + timeout

                while not all(operation["done"] for operation in operations):
                    if deadline is not None and time.monotonic() >= deadline:
                        # NOTE(review): this only stops the client-side wait. The
                        # operations posted above are not cancelled or deleted here,
                        # so they may keep running inside httpstan and delay later
                        # fits (e.g. when called repeatedly in a loop) — confirm.
                        raise TimeoutError(
                            f"Sampling did not finish within {timeout} seconds "
                            f"({percent_complete:.0f}% complete)."
                        )
                    for operation in operations:
                        if operation["done"]:
                            continue
                        resp = await client.get(f"/{operation['name']}")
                        assert resp.status != 404
                        operation.update(resp.json())
                        progress_message = operation["metadata"].get("progress")
                        if not progress_message:
                            continue
                        iteration, iteration_max = map(
                            int, current_and_max_iterations_re.findall(progress_message).pop(0)
                        )
                        if current_iterations.get(operation["name"]) == iteration:
                            continue
                        current_iterations[operation["name"]] = iteration
                        iterations_count = sum(current_iterations.values())
                        total_iterations = iteration_max * num_chains
                        percent_complete = 100 * iterations_count / total_iterations
                        reset_progress_line()
                        sampling_output.write_line(
                            f"<comment>Sampling:</comment> {round(percent_complete):3.0f}% ({iterations_count}/{total_iterations})"
                        )
                    await asyncio.sleep(0.05)

                # Fewer chains reported progress than were requested: the results
                # presumably came from httpstan's cache rather than a fresh run.
                fit_in_cache = len(current_iterations) < num_chains

                stan_outputs = []
                for operation in operations:
                    fit_name = operation["result"].get("name")
                    if fit_name is None:  # operation["result"] is an error
                        assert not str(operation["result"]["code"]).startswith("2"), operation
                        message = operation["result"]["message"]
                        if """ValueError('Initialization failed.')""" in message:
                            sampling_output.clear()
                            sampling_output.write_line("<info>Sampling:</info> <error>Initialization failed.</error>")
                            raise RuntimeError("Initialization failed.")
                        raise RuntimeError(message)

                    resp = await client.get(f"/{fit_name}")
                    if resp.status != 200:
                        raise RuntimeError((resp.json())["message"])
                    stan_outputs.append(resp.content)

                    # clean up after ourselves when fit is uncacheable (no random seed)
                    if self.random_seed is None:
                        resp = await client.delete(f"/{fit_name}")
                        if resp.status not in {200, 202, 204}:
                            raise RuntimeError((resp.json())["message"])

                reset_progress_line()
                sampling_output.write_line(
                    "<info>Sampling:</info> 100%, done."
                    if fit_in_cache
                    else f"<info>Sampling:</info> {percent_complete:3.0f}% ({iterations_count}/{total_iterations}), done."
                )
                if not io.supports_ansi():
                    sampling_output.write("\n")

            stan_outputs = tuple(stan_outputs)  # Fit constructor expects a tuple.

            def is_nonempty_logger_message(msg: simdjson.Object):
                return msg["topic"] == "logger" and msg["values"][0] != "info:"  # type: ignore

            def is_iteration_or_elapsed_time_logger_message(msg: simdjson.Object):
                # Assumes `msg` is a message with topic `logger`.
                text = msg["values"][0]  # type: ignore
                text = cast(str, text)
                return (
                    text.startswith("info:Iteration:")
                    or text.startswith("info: Elapsed Time:")
                    # this detects lines following "Elapsed Time:", part of a multi-line Stan message
                    or text.startswith("info:" + " " * 15)
                )

            parser = simdjson.Parser()
            nonstandard_logger_messages = []
            for stan_output in stan_outputs:
                for line in stan_output.splitlines():
                    # Do not attempt to parse non-logger messages. Draws could contain nan or inf values.
                    # simdjson cannot parse lines containing such values.
                    if b'"logger"' not in line:
                        continue
                    msg = parser.parse(line)
                    if is_nonempty_logger_message(msg) and not is_iteration_or_elapsed_time_logger_message(msg):
                        nonstandard_logger_messages.append(msg.as_dict())
                    del msg
            del parser  # simdjson.Parser is no longer used at this point.

            if nonstandard_logger_messages:
                io.error_line("<comment>Messages received during sampling:</comment>")
                for msg in nonstandard_logger_messages:
                    text = msg["values"][0].replace("info:", "  ").replace("error:", "  ")
                    if text.strip():
                        io.error_line(f"{text}")

            fit = stan.fit.Fit(
                stan_outputs,
                num_chains,
                self.param_names,
                self.constrained_param_names,
                self.dims,
                num_warmup,
                num_samples,
                num_thin,
                save_warmup,
            )

            for entry_point in stan.plugins.get_plugins():
                Plugin = entry_point.load()
                fit = Plugin().on_post_sample(fit)
            return fit

        try:
            return asyncio.run(go())
        except KeyboardInterrupt:
            return  # type: ignore

No idea about modifying PyStan.

I’m not sure what you mean by “sampling stops”. Do you mean you get something like 500 iterations but never see iteration 501? Stan’s warmup works in blocks, so it’s possible your adaptation is getting stuck with a very small step size to deal with areas of high curvature. If that’s the issue, the only good way to fix it is to reparameterize.

Very hard to say without looking at the model. Can you simulate a small data set from the generative process of the model and fit it?