CAT-SOOP is a flexible, programmable learning management system based on the Python programming language. https://catsoop.mit.edu
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 

344 lines
10 KiB

  1. # This file is part of CAT-SOOP
  2. # Copyright (c) 2011-2020 by The CAT-SOOP Developers <catsoop-dev@mit.edu>
  3. #
  4. # This program is free software: you can redistribute it and/or modify it under
  5. # the terms of the GNU Affero General Public License as published by the Free
  6. # Software Foundation, either version 3 of the License, or (at your option) any
  7. # later version.
  8. #
  9. # This program is distributed in the hope that it will be useful, but WITHOUT
  10. # ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
  11. # FOR A PARTICULAR PURPOSE. See the GNU Affero General Public License for more
  12. # details.
  13. #
  14. # You should have received a copy of the GNU Affero General Public License
  15. # along with this program. If not, see <http://www.gnu.org/licenses/>.
  16. """
  17. Logging mechanisms in catsoopdb
  18. From a high-level perspective, CAT-SOOP's logs are sequences of Python objects.
  19. A log is identified by a `db_name` (typically a username), a `path` (a list of
  20. strings starting with a course name), and a `logname` (a string).
  21. On disk, each log is a file containing one or more entries, where each entry
  22. consists of:
* 8 bytes holding the length of the entry (an unsigned 64-bit little-endian
  integer)
* a binary blob (pickled Python object, potentially compressed and/or
  encrypted)
* the same 8-byte length repeated
  27. This module provides functions for interacting with and modifying those logs.
  28. In particular, it provides ways to retrieve the Python objects in a log, or to
  29. add new Python objects to a log.
  30. """
  31. import os
  32. import ast
  33. import sys
  34. import lzma
  35. import base64
  36. import pickle
  37. import struct
  38. import hashlib
  39. import importlib
  40. import contextlib
  41. from collections import OrderedDict
  42. from datetime import datetime, timedelta
  43. from cryptography.hazmat.backends import default_backend
  44. from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
  45. from .fernet import RawFernet
  46. _nodoc = {
  47. "passthrough",
  48. "FileLock",
  49. "SEP_CHARS",
  50. "get_separator",
  51. "good_separator",
  52. "modify_most_recent",
  53. "NoneType",
  54. "OrderedDict",
  55. "datetime",
  56. "timedelta",
  57. "COMPRESS",
  58. "Cipher",
  59. "ENCRYPT_KEY",
  60. "ENCRYPT_PASS",
  61. "RawFernet",
  62. "compress_encrypt",
  63. "decompress_decrypt",
  64. "default_backend",
  65. "log_lock",
  66. "prep",
  67. "sep",
  68. "unprep",
  69. }
  70. @contextlib.contextmanager
  71. def passthrough():
  72. yield
  73. from . import base_context
  74. from filelock import FileLock
  75. importlib.reload(base_context)
  76. COMPRESS = base_context.cs_log_compression
  77. ENCRYPT_KEY = None
  78. ENCRYPT_PASS = os.environ.get("CATSOOP_PASSPHRASE", None)
  79. if ENCRYPT_PASS is not None:
  80. with open(
  81. os.path.join(os.path.dirname(os.environ["CATSOOP_CONFIG"]), "encryption_salt"),
  82. "rb",
  83. ) as f:
  84. SALT = f.read()
  85. ENCRYPT_KEY = hashlib.pbkdf2_hmac(
  86. "sha256", ENCRYPT_PASS.encode("utf8"), SALT, 100000, dklen=32
  87. )
  88. XTS_KEY = hashlib.pbkdf2_hmac("sha256", ENCRYPT_PASS.encode("utf8"), SALT, 100000)
  89. FERNET = RawFernet(ENCRYPT_KEY)
  90. def log_lock(path):
  91. lock_loc = os.path.join(base_context.cs_data_root, "_locks", *path) + ".lock"
  92. os.makedirs(os.path.dirname(lock_loc), exist_ok=True)
  93. return FileLock(lock_loc)
  94. def compress_encrypt(x):
  95. if COMPRESS:
  96. x = lzma.compress(x)
  97. if ENCRYPT_KEY is not None:
  98. x = FERNET.encrypt(x)
  99. return x
  100. def prep(x):
  101. """
  102. Helper function to serialize a Python object.
  103. """
  104. return compress_encrypt(pickle.dumps(x, -1))
  105. def decompress_decrypt(x):
  106. if ENCRYPT_KEY is not None:
  107. x = FERNET.decrypt(x)
  108. if COMPRESS:
  109. x = lzma.decompress(x)
  110. return x
  111. def unprep(x):
  112. """
  113. Helper function to deserialize a Python object.
  114. """
  115. return pickle.loads(decompress_decrypt(x))
  116. def _e(x, seed): # not sure seed is the right term here...
  117. x = x.encode("utf8") + bytes([0] * (16 - len(x)))
  118. b = hashlib.sha512(seed.encode("utf8") + ENCRYPT_KEY + SALT).digest()[-16:]
  119. c = Cipher(algorithms.AES(XTS_KEY), modes.XTS(b), backend=default_backend())
  120. e = c.encryptor()
  121. return base64.urlsafe_b64encode(e.update(x) + e.finalize()).decode("utf8")
  122. def _d(x, seed): # not sure seed is the right term here...
  123. x = base64.urlsafe_b64decode(x)
  124. b = hashlib.sha512(seed.encode("utf8") + ENCRYPT_KEY + SALT).digest()[-16:]
  125. c = Cipher(algorithms.AES(XTS_KEY), modes.XTS(b), backend=default_backend())
  126. d = c.decryptor()
  127. return (d.update(x) + d.finalize()).rstrip(b"\x00").decode("utf8")
  128. def get_log_filename(db_name, path, logname):
  129. """
  130. Helper function, returns the filename where a given log is stored on disk.
  131. **Parameters:**
  132. * `db_name`: the name of the database to look in
  133. * `path`: the path to the page associated with the log
  134. * `logname`: the name of the log
  135. """
  136. if ENCRYPT_KEY is not None:
  137. seed = path[0] if path else db_name
  138. path = [_e(i, seed + i) for i in path]
  139. db_name = _e(db_name, seed + db_name)
  140. logname = _e(logname, seed + repr(path))
  141. if path:
  142. course = path[0]
  143. return os.path.join(
  144. base_context.cs_data_root,
  145. "_logs",
  146. "_courses",
  147. course,
  148. db_name,
  149. *(path[1:]),
  150. "%s.log" % logname
  151. )
  152. else:
  153. return os.path.join(
  154. base_context.cs_data_root, "_logs", db_name, *path, "%s.log" % logname
  155. )
  156. def _modify_log(fname, new, mode):
  157. os.makedirs(os.path.dirname(fname), exist_ok=True)
  158. entry = prep(new)
  159. length = struct.pack("<Q", len(entry))
  160. with open(fname, mode) as f:
  161. f.write(length)
  162. f.write(entry)
  163. f.write(length)
  164. def update_log(db_name, path, logname, new, lock=True):
  165. """
  166. Adds a new entry to the end of the specified log.
  167. **Parameters:**
  168. * `db_name`: the name of the database to update
  169. * `path`: the path to the page associated with the log
  170. * `logname`: the name of the log
  171. * `new`: the Python object that should be added to the end of the log
  172. **Optional Parameters:**
  173. * `lock` (default `True`): whether the database should be locked during
  174. this update
  175. """
  176. fname = get_log_filename(db_name, path, logname)
  177. # get an exclusive lock on this file before making changes
  178. # look up the separator and the data
  179. cm = log_lock([db_name] + path + [logname]) if lock else passthrough()
  180. with cm:
  181. _modify_log(fname, new, "ab")
  182. def overwrite_log(db_name, path, logname, new, lock=True):
  183. """
  184. Overwrites the entire log with a new log with a single (given) entry.
  185. **Parameters:**
  186. * `db_name`: the name of the database to overwrite
  187. * `path`: the path to the page associated with the log
  188. * `logname`: the name of the log
  189. * `new`: the Python object that should be contained in the new log
  190. **Optional Parameters:**
  191. * `lock` (default `True`): whether the database should be locked during
  192. this update
  193. """
  194. # get an exclusive lock on this file before making changes
  195. fname = get_log_filename(db_name, path, logname)
  196. cm = log_lock([db_name] + path + [logname]) if lock else passthrough()
  197. with cm:
  198. _modify_log(fname, new, "wb")
  199. def _read_log(db_name, path, logname, lock=True):
  200. fname = get_log_filename(db_name, path, logname)
  201. # get an exclusive lock on this file before reading it
  202. cm = log_lock([db_name] + path + [logname]) if lock else passthrough()
  203. with cm:
  204. try:
  205. with open(fname, "rb") as f:
  206. while True:
  207. try:
  208. length = struct.unpack("<Q", f.read(8))[0]
  209. yield unprep(f.read(length))
  210. except EOFError:
  211. break
  212. f.seek(8, os.SEEK_CUR)
  213. return
  214. except:
  215. return
  216. def read_log(db_name, path, logname, lock=True):
  217. """
  218. Reads all entries of a log.
  219. **Parameters:**
  220. * `db_name`: the name of the database to read
  221. * `path`: the path to the page associated with the log
  222. * `logname`: the name of the log
  223. **Optional Parameters:**
  224. * `lock` (default `True`): whether the database should be locked during
  225. this read
  226. **Returns:** a list containing the Python objects in the log
  227. """
  228. return list(_read_log(db_name, path, logname, lock))
  229. def most_recent(db_name, path, logname, default=None, lock=True):
  230. """
  231. Ignoring most of the log, grab the last entry.
  232. This code works by reading backward through the log until the separator is
  233. found, treating the piece of the file after the last separator as a log
  234. entry, and using `unprep` to return the associated Python object.
  235. **Parameters:**
  236. * `db_name`: the name of the database to read
  237. * `path`: the path to the page associated with the log
  238. * `logname`: the name of the log
  239. **Optional Parameters:**
  240. * `default` (default `None`): the value to be returned if the log contains
  241. no entries or does not exist
  242. * `lock` (default `True`): whether the database should be locked during
  243. this read
  244. **Returns:** a single Python object representing the most recent entry in
  245. the log.
  246. """
  247. fname = get_log_filename(db_name, path, logname)
  248. if not os.path.isfile(fname):
  249. return default
  250. # get an exclusive lock on this file before reading it
  251. cm = log_lock([db_name] + path + [logname]) if lock else passthrough()
  252. with cm:
  253. with open(fname, "rb") as f:
  254. f.seek(-8, os.SEEK_END)
  255. length = struct.unpack("<Q", f.read(8))[0]
  256. f.seek(-length - 8, os.SEEK_CUR)
  257. return unprep(f.read(length))
  258. def modify_most_recent(
  259. db_name,
  260. path,
  261. logname,
  262. default=None,
  263. transform_func=lambda x: x,
  264. method="update",
  265. lock=True,
  266. ):
  267. cm = log_lock([db_name] + path + [logname]) if lock else passthrough()
  268. with cm:
  269. old_val = most_recent(db_name, path, logname, default, lock=False)
  270. new_val = transform_func(old_val)
  271. if method == "update":
  272. updater = update_log
  273. else:
  274. updater = overwrite_log
  275. updater(db_name, path, logname, new_val, lock=False)
  276. return new_val