Creating Custom Transformation Passes
Create custom transformation and analysis passes using base classes or decorators.
Pass Types
TransformationPass
Base class for passes that modify the model.
Python
from luna_model.transformation import TransformationPass, ActionType
from luna_model.model import Model


class MyTransformationPass(TransformationPass):
    """Custom transformation pass."""

    @property
    def name(self) -> str:
        """Name used to identify this pass (e.g. in ``requires`` lists)."""
        return "my-transformation"

    def run(self, model: Model, cache):
        """Transform ``model`` and report the action taken.

        Return ``(model, action_type)`` or
        ``(model, action_type, analysis_data)``.
        """
        # Transform the model here.
        return model, ActionType.DID_TRANSFORM

    def backwards(self, solution, cache):
        """Map ``solution`` back to the original model's variables."""
        # This example does not change the variables, so the solution is
        # returned as-is.
        return solution
AnalysisPass
Base class for passes that analyze the model without modifying it.
Python
from luna_model.transformation import AnalysisPass


class MyAnalysisPass(AnalysisPass):
    """Custom analysis pass."""

    @property
    def name(self) -> str:
        """Name used to identify this pass (e.g. in ``requires`` lists)."""
        return "my-analysis"

    def run(self, model, cache):
        """Inspect ``model`` and return the collected data."""
        # Analysis passes return their data directly rather than a
        # (model, action_type) tuple.
        return {
            "num_vars": model.num_variables(),
        }
Using Decorators
@transform Decorator
Create transformation passes from functions.
Python
from luna_model.transformation import transform, ActionType, PassManager


@transform(name="scale-objective")
def scale_objective(model, cache):
    """Scale objective by 2."""
    model.objective = model.objective * 2.0
    return model, ActionType.DID_TRANSFORM


# Use in PassManager (``model`` is an existing Model instance)
pm = PassManager([scale_objective])
ir = pm.run(model)
@analyse Decorator
Create analysis passes from functions.
Python
from collections import Counter

from luna_model.transformation import analyse, PassManager


@analyse(name="count-variables")
def count_vars(model, cache):
    """Count variables by type.

    Returns a dict mapping each ``vtype`` that occurs in ``model`` to the
    number of variables with that type.
    """
    return dict(Counter(var.vtype for var in model.variables()))


# Use in PassManager (``model`` is an existing Model instance)
pm = PassManager([count_vars])
ir = pm.run(model)
Simple Example
Python
from luna_model.transformation import analyse, transform, ActionType, PassManager
from luna_model import Vtype


# Analysis pass using decorator
@analyse(name="count-binary-vars")
def count_binary(model, cache):
    """Count binary variables."""
    return sum(1 for v in model.variables() if v.vtype == Vtype.BINARY)


# Transformation pass using decorator
@transform(name="double-objective")
def double_obj(model, cache):
    """Double the objective coefficients."""
    model.objective = model.objective * 2
    return model, ActionType.DID_TRANSFORM


# Use in pipeline (``model`` is an existing Model instance)
pm = PassManager([count_binary, double_obj])
ir = pm.run(model)
Pass Dependencies
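Passes can declare dependencies on other passes with `requires` (passed to the decorator in the example below). The PassManager runs the required passes first, and their results can be read from `cache` under the required pass's name.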
Python
from luna_model.transformation import analyse, transform, ActionType, PassManager


# Analysis pass
@analyse(name="max-coefficient")
def find_max_coeff(model, cache):
    """Find maximum objective coefficient."""
    max_val = 0
    # ... compute max_val ...
    return max_val


# Transformation that requires the analysis
@transform(name="scale-by-max", requires=["max-coefficient"])
def scale_by_max(model, cache):
    """Scale objective by max coefficient."""
    # The required analysis result is stored in the cache under its pass name.
    max_coeff = cache["max-coefficient"]
    # Only divide by a positive value: avoids division by zero and sign flips.
    if max_coeff > 0:
        model.objective = model.objective / max_coeff
    return model, ActionType.DID_TRANSFORM


# PassManager will run max-coefficient before scale-by-max
pm = PassManager([find_max_coeff, scale_by_max])
Backwards Transformation
All transformation passes must implement the `backwards` method, which maps solutions back to the original variable space. If your transformation does not change the variable mapping, simply return the solution unchanged:
Python
from luna_model.transformation import TransformationPass, ActionType


class MyTransformationPass(TransformationPass):
    """Custom transformation pass."""

    @property
    def name(self) -> str:
        """Name used to identify this pass (e.g. in ``requires`` lists)."""
        return "my-transform"

    def run(self, model, cache):
        """Transform the model."""
        # Apply transformation
        return model, ActionType.DID_TRANSFORM

    def backwards(self, solution, cache):
        """Map solution back to original variables."""
        # If no variable mapping changes, return as-is
        return solution
Best Practices
Use Descriptive Names
Python
@analyse(name="count-quadratic-terms")
def count_quadratic(model, cache):
    """Count quadratic terms in objective."""
    ...  # implementation omitted
Document Your Passes
Python
@transform(name="normalize-coefficients")
def normalize(model, cache):
    """
    Normalize objective coefficients to [-1, 1] range.

    Finds the maximum absolute coefficient and scales all
    coefficients accordingly.
    """
    ...  # implementation omitted
See Also
- PassManager - Using passes
- Built-in Passes - Available passes