"""Launch user code, relay its output, and react to messages on a ZMQ bus.

The user's ``main.py`` is run with ``python3``. Its combined stdout/stderr is
echoed to this process's stdout, written to ``log.txt`` in the code directory,
and published on the ``log`` topic of a PUB socket bound to ``ZMQ_SOCKET``.
A ``BusHandler`` thread subscribes to that endpoint and restarts the user code
when it receives a ``restart`` message.
"""
import subprocess
import threading
from pathlib import Path
import logging
import os
import time
import atexit
import json

import zmq

# Metadata file rewritten by the registry; its path is given to the user code
# through the METADATA_FILE environment variable.
GLOBAL_METADATA = Path.cwd() / "global-metadata.json"
# Filesystem path of the IPC endpoint used for the PUB/SUB bus.
ZMQ_SOCKET = Path.cwd() / "runusb.sock"

LOGGER = logging.getLogger(__name__)


def get_zmq_publisher():
    """Create a PUB socket (in a fresh context) bound to ``ZMQ_SOCKET``."""
    context = zmq.Context()
    socket = context.socket(zmq.PUB)
    socket.bind("ipc://" + str(ZMQ_SOCKET))
    return socket


def get_zmq_subscriber(topic: str):
    """Create a SUB socket (in a fresh context) connected to ``ZMQ_SOCKET``.

    Args:
        topic: Subscription prefix; an empty string subscribes to everything.
    """
    context = zmq.Context()
    socket = context.socket(zmq.SUB)
    socket.setsockopt_string(zmq.SUBSCRIBE, topic)
    socket.connect("ipc://" + str(ZMQ_SOCKET))
    return socket


class UserCodeSupervisor(threading.Thread):
    """Thread relaying a child process's output until its stdout closes.

    Every line is echoed to stdout, written to ``<mountpoint>/log.txt`` and
    published as a ``[b"log", line]`` multipart message.
    """

    def __init__(self, process: subprocess.Popen, mountpoint: Path):
        super().__init__()
        self.process = process
        self.logfile = mountpoint / "log.txt"
        self.zmq_publisher = get_zmq_publisher()

    def run(self):
        with self.logfile.open(mode="w") as log_file:
            # Read until EOF instead of looping on returncode: polling can stop
            # while lines are still buffered in the pipe, and after EOF
            # readline() keeps returning "" (which would be published as empty
            # log messages) until poll() notices the exit.
            for line in self.process.stdout:
                print(line, end="")
                log_file.write(line)
                log_file.flush()
                self.zmq_publisher.send_multipart([b"log", line.encode()])
        self.process.stdout.close()
        # Reap the child so its return code is recorded.
        self.process.wait()


class RunUSBRegistry:
    """Owns the running user-code process and the shared metadata.

    Use ``with registry:`` to hold the registry's lock around operations.
    """

    process: subprocess.Popen | None = None
    usercode_dir: Path | None = None

    def __init__(self):
        self.lock = threading.Lock()
        self.metadata = {}
        # Write the initial (empty) metadata file.
        self.update_metadata({})

    def start_user_code(self, mountpoint: Path):
        """Run ``<mountpoint>/main.py`` with python3 and start a supervisor for it.

        stdout and stderr are merged into one text pipe read by the supervisor;
        stdin is closed.
        """
        self.usercode_dir = mountpoint
        entrypoint = mountpoint / "main.py"
        self.process = subprocess.Popen(
            ["python3", entrypoint],
            env={
                **os.environ,
                "PYTHONUNBUFFERED": "1",
                "METADATA_FILE": str(GLOBAL_METADATA),
            },
            cwd=mountpoint,
            stdin=subprocess.DEVNULL,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            universal_newlines=True,
        )
        UserCodeSupervisor(self.process, mountpoint).start()

    def restart_user_code(self):
        """Stop the user code and start it again from the same directory.

        If no user code has been started there is no directory to restart
        from, so a warning is logged and nothing else happens.
        """
        mountpoint = self.usercode_dir
        if mountpoint is None:
            LOGGER.warning("Restart requested but no user code has been started")
            return
        self.stop_user_code()
        self.start_user_code(mountpoint)

    def stop_user_code(self):
        """Terminate the user code, killing it if it has not exited within 5 s.

        Safe to call when nothing is running (e.g. from ``close`` at exit).
        """
        if self.process is None:
            return
        self.process.terminate()
        try:
            self.process.wait(timeout=5)
        except subprocess.TimeoutExpired:
            # Still running after SIGTERM: force it, then reap it.
            self.process.kill()
            self.process.wait()

        # Remove the reference
        self.process = None
        self.usercode_dir = None

    def handle_metadata_file(self, metadata_file: Path):
        """Replace the metadata with the JSON content of *metadata_file*."""
        with metadata_file.open(mode="r") as f:
            self.metadata = json.load(f)

        # Notify with a blank update, as everything is new
        self.update_metadata({})

    def update_metadata(self, new_metadata: dict):
        """Merge *new_metadata* into the metadata and rewrite ``GLOBAL_METADATA``."""
        self.metadata.update(new_metadata)
        with GLOBAL_METADATA.open("w") as f:
            json.dump(self.metadata, f)

    def __enter__(self):
        self.lock.acquire()
        return self

    def __exit__(self, *args):
        self.lock.release()

    def close(self):
        """Stop any running user code (registered with ``atexit``)."""
        self.stop_user_code()


class BusHandler(threading.Thread):
    """Thread that listens on the bus and restarts user code on ``restart``."""

    def __init__(self, registry: RunUSBRegistry):
        super().__init__()
        self.registry = registry
        self.zmq_subscriber = get_zmq_subscriber("")

    def run(self):
        while True:
            topic, message = self.zmq_subscriber.recv_multipart()
            topic = topic.decode()
            message = message.decode()

            # if topic == "log":
            #     continue

            if topic == "restart":
                # Use the registry this thread was given; ``registry`` is only
                # a local name inside main() and is not visible here.
                with self.registry:
                    self.registry.restart_user_code()
            else:
                print("Ignoring message with topic", topic, message.strip())


def watch_drive_activity(registry: RunUSBRegistry):
    """Start the user code from the current working directory."""
    with registry:
        registry.start_user_code(Path.cwd())


def watch_metadata(registry: RunUSBRegistry):
    """Load ``metadata.json`` from the current working directory into the registry."""
    with registry:
        registry.handle_metadata_file(Path.cwd() / "metadata.json")


def main():
    logging.basicConfig(level=logging.DEBUG)

    registry = RunUSBRegistry()
    atexit.register(registry.close)

    BusHandler(registry).start()

    watch_drive_activity(registry)
    watch_metadata(registry)

    # Threads are running now


if __name__ == '__main__':
    main()