from __future__ import annotations
import ast
import inspect
import re
from pydantic import BaseModel, Field
from mkstd import YamlStandard
# For PyTorch import/export support
try:
import torch.fx
import torch.nn as nn
except ImportError:
pass
__all__ = ["Input", "Layer", "Node", "NNModel", "NNModelStandard"]
[docs]
class Layer(BaseModel):
    """Specify layers."""

    # Identifier of the layer within the model.
    layer_id: str
    # Name of the layer class, e.g. the name of a `torch.nn` module class.
    layer_type: str
    # Arguments of the layer; optional.
    # FIXME currently handled as kwargs
    # TODO class of layer-specific supported args
    args: dict | None = None
[docs]
class Node(BaseModel):
    """A node of the computational graph.
    e.g. a node in the forward call of a PyTorch model.
    Ref: https://pytorch.org/docs/stable/fx.html#torch.fx.Node
    """

    # Name of the node; other nodes refer to this node by this name.
    name: str
    # Kind of operation, e.g. "placeholder", "call_function", "call_method",
    # "call_module" or "output".
    op: str
    # What the operation acts on, e.g. a function name or a layer ID.
    target: str
    # Positional arguments of the operation; optional.
    args: list | None = None
    # Keyword arguments of the operation; optional.
    kwargs: dict | None = None
# PyTorch tends to leave attributes that still have their default value out
# of `Module.__repr__`, so some module attributes are only conditionally
# visible there. They are read from the module explicitly instead, so tools
# that consume the saved model never need to know the PyTorch defaults.
#
# Specification, per layer type:
# - "__all__": every attribute to save, in output order;
# - "getters": custom accessors, for attributes that cannot simply be read
#   with `getattr`.
extra_repr = {
    "Conv2d": {
        # `output_padding` is missing from `Conv2d.__init__`, so it is
        # deliberately not saved.
        "__all__": ["padding", "dilation", "groups", "bias", "padding_mode"],
        "getters": {
            # `bias` is saved as a flag, not as the parameter itself.
            "bias": lambda m: m.bias is not None,
        },
    },
    "RNN": {
        "__all__": [
            "input_size",
            "hidden_size",
            "proj_size",
            "num_layers",
            "bias",
            "batch_first",
            "dropout",
            "bidirectional",
        ],
        "getters": {},
    },
}

# Resolve the specification into `{layer type: {attribute: getter}}`.
# Attributes without a custom getter are read directly from the module;
# `attr=attr` binds the current attribute name to each generated getter.
extra_repr = {
    layer_type: {
        attr: (
            spec["getters"][attr]
            if attr in spec["getters"]
            else (lambda m, attr=attr: getattr(m, attr))
        )
        for attr in spec["__all__"]
    }
    for layer_type, spec in extra_repr.items()
}
# Keyword arguments that PyTorch applies by default, keyed by function name.
# They are stored explicitly so that they are always saved to YAML.
default_kwargs = {
    "cat": {"dim": 0},  # torch.cat
}
def extract_module_args(module: nn.Module) -> dict:
    """Get the arguments used to create the module.

    The arguments are collected from several sources and merged, with later
    sources taking precedence over earlier ones:

    1. attributes listed in ``module.__constants__`` that are also
       parameters of ``module.__init__``;
    2. the getters registered for the module's layer type in `extra_repr`;
    3. positional arguments shown in ``repr(module)``, matched in order to
       the required positional parameters of ``module.__init__``;
    4. keyword arguments shown in ``repr(module)``.

    N.B.: currently, all arguments shown in the `repr` must be Python
    literals compatible with `ast.literal_eval`, and cannot contain the
    character `=`.

    Args:
        module:
            The module.

    Returns:
        The arguments, as keyword arguments.

    Raises:
        ValueError:
            If ``repr(module)`` does not have the single-line form
            ``ClassName(...)`` (e.g. the multi-line `repr` that PyTorch
            produces for modules that contain submodules).
        ValueError, SyntaxError:
            Raised by `ast.literal_eval` if an argument shown in the `repr`
            is not a Python literal.
    """
    init_parameters = inspect.signature(module.__init__).parameters

    # Stage 1: get all arguments in the intersection of `__constants__` and
    # `__init__`. A module without `__constants__` contributes nothing here.
    constant_init_args = {
        arg: getattr(module, arg)
        for arg in getattr(module, "__constants__", ())
        if arg in init_parameters
    }

    # Stage 2: add all arguments suggested in the `__repr__`.
    ## Get names of module positional arguments, i.e. the parameters that
    ## have no default and can be passed positionally.
    positional_kinds = (
        inspect.Parameter.POSITIONAL_ONLY,
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
    )
    arg_names = [
        parameter.name
        for parameter in init_parameters.values()
        if parameter.name != "self"
        and parameter.default is inspect.Parameter.empty
        and parameter.kind in positional_kinds
    ]

    ## Extract all arguments
    module_repr = repr(module).strip()
    repr_match = re.match(r"(\w+)\((.*)\)", module_repr)
    if repr_match is None:
        raise ValueError(
            "Cannot extract arguments from the module representation; "
            f"expected the single-line form `ClassName(...)`: {module_repr!r}"
        )
    all_args_str = repr_match.group(2)

    ## Positional arguments are everything before the first `key=value`.
    ## Splitting on "," may cut through tuples; joining the pieces again
    ## restores them.
    positional_pieces = []
    for arg_str in all_args_str.split(","):
        if "=" in arg_str:
            break
        positional_pieces.append(arg_str)
    positional_str = ",".join(positional_pieces).strip()

    args = {}
    # A `repr` without positional arguments (e.g. `ReLU()`) yields an empty
    # string here, which must not be passed to `ast.literal_eval`.
    if positional_str:
        # The trailing comma ensures that a tuple is parsed, even for a
        # single argument.
        args = dict(zip(arg_names, ast.literal_eval(positional_str + ",")))

    ## A keyword value ends right before the next `, key=`, or at the end.
    kwargs = {
        kw: ast.literal_eval(arg_str.strip())
        for kw, arg_str in re.findall(
            r"(\w*)=([^=]*)(?=,\s*\w+=|$)", all_args_str
        )
    }

    # See `extra_repr`
    extra = {
        kw: getter(module)
        for kw, getter in extra_repr.get(
            get_module_layer_type(module), {}
        ).items()
    }

    return constant_init_args | extra | args | kwargs
def get_module_layer_type(module: nn.Module) -> str:
    """Return the layer type of a module, i.e. the name of its class.

    Args:
        module:
            The module.

    Returns:
        The name of the class of `module`.
    """
    module_class = type(module)
    return module_class.__name__
def _convert_arg_to_str(arg, pytorch_nodes):
"""Replace modules with their IDs for serialization."""
if isinstance(arg, (list, tuple)):
return [
_convert_arg_to_str(arg=arg_i, pytorch_nodes=pytorch_nodes)
for arg_i in arg
]
return arg if arg not in pytorch_nodes else str(arg)
def _convert_str_to_arg(string, state):
"""Replace module IDs with modules for deserialization."""
if isinstance(string, list):
return [
_convert_str_to_arg(string=string_i, state=state)
for string_i in string
]
return string if string not in state else state[string]
[docs]
class NNModel(BaseModel):
"""An easy-to-use format to specify simple deep NN models.
There is a function to export this to a PyTorch module, or to YAML.
"""
nn_model_id: str
inputs: list[Input]
layers: list[Layer]
"""The components of the model (e.g., layers of a neural network)."""
forward: list[Node]
[docs]
@staticmethod
def from_pytorch_module(
module: nn.Module, nn_model_id: str, inputs: list[Input] = None,
) -> NNModel:
"""Create a PEtab SciML NN model from a pytorch module.
If `inputs` are not provided, their info will be generated from the
network.
"""
layers = []
layer_ids = []
for layer_id, layer_module in module.named_modules():
if not layer_id:
# first entry is all modules combined
continue
layer = Layer(
layer_id=layer_id,
layer_type=get_module_layer_type(layer_module),
args=extract_module_args(module=layer_module),
)
layers.append(layer)
layer_ids.append(layer_id)
nodes = []
node_names = []
pytorch_nodes = list(torch.fx.symbolic_trace(module).graph.nodes)
generate_inputs = False
if inputs is None:
generate_inputs = True
inputs = []
for pytorch_node in pytorch_nodes:
op = pytorch_node.op
target = pytorch_node.target
if op == "call_function":
target = pytorch_node.target.__name__
if op == "placeholder" and generate_inputs:
inputs.append(Input(input_id=pytorch_node.target))
# Convert module args to strings
args = [
_convert_arg_to_str(arg=arg, pytorch_nodes=pytorch_nodes)
for arg in pytorch_node.args
]
kwargs = default_kwargs.get(target, {}) | pytorch_node.kwargs
node = Node(
name=pytorch_node.name,
op=op,
target=target,
args=args,
kwargs=kwargs,
)
nodes.append(node)
node_names.append(node.name)
nn_model = NNModel(
nn_model_id=nn_model_id,
inputs=inputs,
layers=layers,
forward=nodes,
)
return nn_model
[docs]
def to_pytorch_module(self) -> nn.Module:
"""Create a pytorch module from a PEtab SciML NN model."""
self2 = self
class _PytorchModule(nn.Module):
def __init__(self) -> None:
super().__init__()
for layer in self2.layers:
setattr(
self,
layer.layer_id,
getattr(nn, layer.layer_type)(**layer.args),
)
graph = torch.fx.Graph()
state = {}
for node in self.forward:
args = []
if node.args:
# Convert strings to modules
args = [
_convert_str_to_arg(string=node_arg, state=state)
for node_arg in node.args
]
args = tuple(args)
kwargs = {}
if node.kwargs:
kwargs = {
k: _convert_str_to_arg(string=v, state=state)
for k, v in node.kwargs.items()
}
match node.op:
case "placeholder":
state[node.name] = graph.placeholder(node.target)
case "call_function":
if node.target in ["flatten", "cat"]:
function = getattr(torch, node.target)
else:
function = getattr(nn.functional, node.target)
state[node.name] = graph.call_function(
function, args, kwargs
)
case "call_method":
state[node.name] = graph.call_method(
node.target, args, kwargs
)
case "call_module":
state[node.name] = graph.call_module(
node.target, args, kwargs
)
case "output":
graph.output(args[0])
return torch.fx.GraphModule(_PytorchModule(), graph)
# The standard for NN model files, based on the `NNModel` data model.
NNModelStandard = YamlStandard(model=NNModel)


if __name__ == "__main__":
    from pathlib import Path

    from mkstd import JsonStandard

    # Both schema files are saved to the same directory.
    schema_dir = Path(__file__).resolve().parents[4] / "doc" / "standard"

    # The NN model file is YAML-formatted, so the schema is provided in YAML
    # format.
    NNModelStandard.save_schema(schema_dir / "nn_model_schema.yaml")

    # However, the schema format is JsonSchema, so the schema is also
    # provided redundantly in JSON schema.
    NNModelStandardJson = JsonStandard(model=NNModel)
    NNModelStandardJson.save_schema(schema_dir / "nn_model_schema.json")