runusb-2/runusb.py

170 lines
4.3 KiB
Python
Raw Permalink Normal View History

2023-05-18 23:55:17 +01:00
import subprocess
import threading
from pathlib import Path
import logging
import os
import time
import atexit
import json
import zmq
# JSON file holding the merged metadata; its path is passed to user code in
# the METADATA_FILE environment variable. Resolved against the working
# directory at import time.
GLOBAL_METADATA = Path.cwd().joinpath("global-metadata.json")
# Filesystem path of the ZeroMQ "ipc://" endpoint used for the pub/sub bus.
ZMQ_SOCKET = Path.cwd().joinpath("runusb.sock")
def get_zmq_publisher():
    """Return a PUB socket bound to the runusb IPC endpoint (``ZMQ_SOCKET``).

    The socket is created from the process-wide ``zmq.Context.instance()``
    rather than a fresh context per call: every context owns its own I/O
    thread(s), and this function is called once per supervised process, so
    a per-call context would accumulate them across restarts.
    """
    socket = zmq.Context.instance().socket(zmq.PUB)
    socket.bind("ipc://" + str(ZMQ_SOCKET))
    return socket
def get_zmq_subscriber(topic: str):
    """Return a SUB socket connected to the runusb IPC endpoint (``ZMQ_SOCKET``).

    topic: topic prefix to subscribe to; ``""`` receives every topic.

    Like the publisher, the socket comes from the shared
    ``zmq.Context.instance()`` instead of a new context per call.
    """
    socket = zmq.Context.instance().socket(zmq.SUB)
    socket.setsockopt_string(zmq.SUBSCRIBE, topic)
    socket.connect("ipc://" + str(ZMQ_SOCKET))
    return socket
class UserCodeSupervisor(threading.Thread):
    """Thread that drains the output of a running user-code process.

    Each line the child writes to its stdout pipe (stderr is expected to be
    merged into it by the caller) is echoed to this process's stdout,
    written to ``<mountpoint>/log.txt`` and published on the ZeroMQ bus
    under the ``log`` topic.
    """

    def __init__(self, process: subprocess.Popen, mountpoint: Path):
        super().__init__()
        self.process = process
        self.logfile = mountpoint / "log.txt"
        self.zmq_publisher = get_zmq_publisher()

    def run(self):
        """Forward output line by line until the child closes its stdout.

        Reading until EOF (instead of until ``returncode`` is set) ensures the
        last lines a process prints just before exiting are not dropped. It
        also avoids spinning on empty reads, and publishing empty log
        messages, when stdout hits EOF before the process has been reaped.
        """
        try:
            with self.logfile.open(mode="w") as log_file:
                # readline() returns "" only at EOF, i.e. once every writer
                # has closed its end of the pipe.
                for line in iter(self.process.stdout.readline, ""):
                    print(line, end="")
                    log_file.write(line)
                    log_file.flush()
                    self.zmq_publisher.send_multipart([b"log", line.encode()])
            self.process.stdout.close()
            # Reap the child so its exit status is recorded.
            self.process.wait()
        finally:
            # A new publisher is created per supervisor; release it so that
            # restarts don't accumulate open sockets.
            self.zmq_publisher.close()
class RunUSBRegistry:
process: subprocess.Popen | None = None
usercode_dir: Path | None = None
def __init__(self):
self.lock = threading.Lock()
self.metadata = {}
self.update_metadata({})
def start_user_code(self, mountpoint: Path):
self.usercode_dir = mountpoint
entrypoint = mountpoint / "main.py"
self.process = subprocess.Popen(
["python3", entrypoint],
env={
**os.environ,
"PYTHONUNBUFFERED": "1",
"METADATA_FILE": str(GLOBAL_METADATA)
},
cwd=mountpoint,
stdin=subprocess.DEVNULL,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
universal_newlines=True,
)
UserCodeSupervisor(self.process, mountpoint).start()
def restart_user_code(self):
mountpoint = self.usercode_dir
self.stop_user_code()
self.start_user_code(mountpoint)
def stop_user_code(self):
self.process.terminate()
try:
self.process.wait(timeout=5)
except subprocess.TimeoutExpired:
pass
if self.process.poll() is not None:
self.process.kill()
# Remove the reference
self.process = None
self.usercode_dir = None
def handle_metadata_file(self, metadata_file: Path):
with metadata_file.open(mode="r") as f:
self.metadata = json.load(f)
# Notify with a blank update, as everything is new
self.update_metadata({})
def update_metadata(self, new_metadata: dict):
self.metadata.update(new_metadata)
with GLOBAL_METADATA.open("w") as f:
json.dump(self.metadata, f)
def __enter__(self):
self.lock.acquire()
def __exit__(self, *args):
self.lock.release()
def close(self):
self.stop_user_code()
class BusHandler(threading.Thread):
    """Thread that listens on the ZeroMQ bus and acts on command topics.

    Only the ``restart`` topic is acted on: it restarts the user code held by
    *registry* (under the registry's lock). Every other topic is printed and
    ignored.

    NOTE(review): the only publisher visible in this file is
    UserCodeSupervisor, which sends topic "log". Those lines are therefore
    received here and echoed again via the "Ignoring" branch. Confirm where
    "restart" messages are meant to come from, and whether "log" should be
    skipped.
    """

    def __init__(self, registry: RunUSBRegistry):
        super().__init__()
        self.registry = registry
        # An empty prefix subscribes to every topic.
        self.zmq_subscriber = get_zmq_subscriber("")

    def run(self):
        while True:
            frames = self.zmq_subscriber.recv_multipart()
            if len(frames) != 2:
                # Unpacking a malformed message would raise and end this thread.
                print("Ignoring malformed message with", len(frames), "frames")
                continue
            topic = frames[0].decode(errors="replace")
            message = frames[1].decode(errors="replace")
            if topic == "restart":
                try:
                    # Use the registry this handler was given. The bare name
                    # `registry` is a local of main(), not a global.
                    with self.registry:
                        self.registry.restart_user_code()
                except Exception:
                    # Thread boundary: log and keep serving the bus.
                    logging.exception("Failed to restart user code")
            else:
                print("Ignoring message with topic", topic, message.strip())
def watch_drive_activity(registry: RunUSBRegistry):
    """Start the user code located in the current working directory.

    The registry's lock is held for the duration of the launch.

    NOTE(review): despite the name, this runs once and does not watch
    anything.
    """
    with registry:
        usercode_root = Path.cwd()
        registry.start_user_code(usercode_root)
def watch_metadata(registry: RunUSBRegistry):
    """Load ``metadata.json`` from the current working directory into *registry*.

    The registry's lock is held while the file is read.

    NOTE(review): despite the name, this runs once and does not watch the
    file for changes.
    """
    with registry:
        metadata_path = Path.cwd().joinpath("metadata.json")
        registry.handle_metadata_file(metadata_path)
def main():
    """Start the bus listener, launch the user code, then load metadata.

    Returns once everything has been started. BusHandler loops with
    ``while True`` in a non-daemon thread, which keeps the process alive.
    """
    logging.basicConfig(level=logging.DEBUG)
    reg = RunUSBRegistry()
    # NOTE(review): CPython runs atexit callbacks only after all non-daemon
    # threads have finished. BusHandler never leaves its loop, so this
    # cleanup likely won't fire on a normal exit; confirm the shutdown path.
    atexit.register(reg.close)
    bus = BusHandler(reg)
    bus.start()
    # NOTE(review): user code starts before metadata.json is loaded, so it
    # may first see the initial empty metadata file.
    watch_drive_activity(reg)
    watch_metadata(reg)
    # Threads are running now


if __name__ == "__main__":
    # Only run when executed as a script, not when imported.
    main()