169 lines
4.3 KiB
Python
169 lines
4.3 KiB
Python
import subprocess
|
|
import threading
|
|
from pathlib import Path
|
|
import logging
|
|
import os
|
|
import time
|
|
import atexit
|
|
import json
|
|
import zmq
|
|
|
|
# Both paths are resolved against the working directory at import time.

# Shared metadata file; its path is handed to user code via METADATA_FILE.
GLOBAL_METADATA = Path.cwd().joinpath("global-metadata.json")

# Filesystem path backing the ZeroMQ "ipc://" endpoint used as the runusb bus.
ZMQ_SOCKET = Path.cwd().joinpath("runusb.sock")


def get_zmq_publisher():
    """Return a PUB socket bound to the runusb IPC endpoint.

    Every call creates its own ZeroMQ context and binds a fresh socket to
    ``ipc://<ZMQ_SOCKET>``.
    """
    ctx = zmq.Context()
    pub = ctx.socket(zmq.PUB)
    pub.bind(f"ipc://{ZMQ_SOCKET}")
    return pub


def get_zmq_subscriber(topic: str):
    """Return a SUB socket connected to the runusb IPC endpoint.

    Every call creates its own ZeroMQ context. The subscription filter is
    installed before connecting.

    :param topic: subscription prefix; ``""`` subscribes to every message.
    """
    ctx = zmq.Context()
    sub = ctx.socket(zmq.SUB)
    sub.setsockopt_string(zmq.SUBSCRIBE, topic)
    sub.connect(f"ipc://{ZMQ_SOCKET}")
    return sub


class UserCodeSupervisor(threading.Thread):
    """Relay a user-code process's output to the console, a log file and ZMQ.

    Each line the child writes is printed, written to ``<mountpoint>/log.txt``
    (truncated when the thread starts) and published on the ``log`` topic.
    The process must have been started with a text-mode ``stdout=PIPE``.
    """

    def __init__(self, process: subprocess.Popen, mountpoint: Path):
        super().__init__()
        self.process = process
        self.logfile = mountpoint / "log.txt"
        self.zmq_publisher = get_zmq_publisher()

    def run(self):
        """Stream the child's output until EOF, then reap it and clean up."""
        try:
            with self.logfile.open(mode="w") as log_file:
                # Read until EOF rather than until the process has exited.
                # Stopping on returncode can drop lines still buffered in the
                # pipe. It can also spin on readline() == "" between EOF and
                # exit, publishing empty log messages.
                for line in self.process.stdout:
                    print(line, end="")
                    log_file.write(line)
                    log_file.flush()
                    self.zmq_publisher.send_multipart([b"log", line.encode()])
            # Reap the child so returncode is set and no zombie is left behind.
            self.process.wait()
        finally:
            # Release the pipe and the publisher socket deterministically
            # instead of waiting for garbage collection.
            self.process.stdout.close()
            self.zmq_publisher.close()


class RunUSBRegistry:
|
|
process: subprocess.Popen | None = None
|
|
usercode_dir: Path | None = None
|
|
|
|
def __init__(self):
|
|
self.lock = threading.Lock()
|
|
self.metadata = {}
|
|
self.update_metadata({})
|
|
|
|
def start_user_code(self, mountpoint: Path):
|
|
self.usercode_dir = mountpoint
|
|
|
|
entrypoint = mountpoint / "main.py"
|
|
self.process = subprocess.Popen(
|
|
["python3", entrypoint],
|
|
env={
|
|
**os.environ,
|
|
"PYTHONUNBUFFERED": "1",
|
|
"METADATA_FILE": str(GLOBAL_METADATA)
|
|
},
|
|
cwd=mountpoint,
|
|
stdin=subprocess.DEVNULL,
|
|
stdout=subprocess.PIPE,
|
|
stderr=subprocess.STDOUT,
|
|
universal_newlines=True,
|
|
)
|
|
|
|
UserCodeSupervisor(self.process, mountpoint).start()
|
|
|
|
def restart_user_code(self):
|
|
mountpoint = self.usercode_dir
|
|
|
|
self.stop_user_code()
|
|
self.start_user_code(mountpoint)
|
|
|
|
def stop_user_code(self):
|
|
self.process.terminate()
|
|
|
|
try:
|
|
self.process.wait(timeout=5)
|
|
except subprocess.TimeoutExpired:
|
|
pass
|
|
|
|
if self.process.poll() is not None:
|
|
self.process.kill()
|
|
|
|
# Remove the reference
|
|
self.process = None
|
|
self.usercode_dir = None
|
|
|
|
def handle_metadata_file(self, metadata_file: Path):
|
|
with metadata_file.open(mode="r") as f:
|
|
self.metadata = json.load(f)
|
|
|
|
# Notify with a blank update, as everything is new
|
|
self.update_metadata({})
|
|
|
|
def update_metadata(self, new_metadata: dict):
|
|
self.metadata.update(new_metadata)
|
|
|
|
with GLOBAL_METADATA.open("w") as f:
|
|
json.dump(self.metadata, f)
|
|
|
|
def __enter__(self):
|
|
self.lock.acquire()
|
|
|
|
def __exit__(self, *args):
|
|
self.lock.release()
|
|
|
|
def close(self):
|
|
self.stop_user_code()
|
|
|
|
|
|
class BusHandler(threading.Thread):
    """Listen on the runusb ZMQ bus and act on control messages.

    Messages are expected as two frames: ``[topic, payload]``. Only the
    ``restart`` topic triggers an action. Everything else is printed and
    ignored. The subscription covers all topics, so this includes the ``log``
    messages that UserCodeSupervisor publishes on the same endpoint.
    """

    def __init__(self, registry: "RunUSBRegistry"):
        super().__init__()
        self.registry = registry
        self.zmq_subscriber = get_zmq_subscriber("")

    def run(self):
        while True:
            frames = self.zmq_subscriber.recv_multipart()
            if len(frames) != 2:
                # A malformed message must not kill the listener thread.
                print("Ignoring message with", len(frames), "frames")
                continue

            # errors="replace" keeps undecodable payloads from raising.
            topic = frames[0].decode(errors="replace")
            message = frames[1].decode(errors="replace")

            if topic == "restart":
                # Use the registry this handler was constructed with, holding
                # its lock while the user code is restarted.
                with self.registry:
                    self.registry.restart_user_code()
            else:
                print("Ignoring message with topic", topic, message.strip())


def watch_drive_activity(registry: RunUSBRegistry):
    """Start the user code found in the current working directory.

    The registry lock is held while the user code is launched.
    """
    with registry:
        usercode_root = Path.cwd()
        registry.start_user_code(usercode_root)


def watch_metadata(registry: RunUSBRegistry):
    """Load ``metadata.json`` from the working directory into the registry.

    The registry lock is held while the file is read and re-published.
    """
    with registry:
        metadata_path = Path.cwd().joinpath("metadata.json")
        registry.handle_metadata_file(metadata_path)


def main():
    """Wire up the registry and bus listener, then launch the user code."""
    logging.basicConfig(level=logging.DEBUG)

    registry = RunUSBRegistry()
    # Ask atexit to stop any running user code when the interpreter exits.
    atexit.register(registry.close)

    bus_handler = BusHandler(registry)
    bus_handler.start()

    watch_drive_activity(registry)
    watch_metadata(registry)

    # The bus handler thread (and the user-code supervisor thread started by
    # watch_drive_activity) keep running after this function returns.


if __name__ == "__main__":
    # Only start the service when executed as a script, not on import.
    main()