Coverage for src/fad/Gradients.py : 20%

Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1#!/usr/bin/env python3
3from FADiff import FADiff
class Scal:
    """
    Scalar variable that carries a value together with its partial derivatives.

    Each instance stores a value (``_val``), a dictionary of partial
    derivatives keyed by the variables they are taken with respect to
    (``_der``), and the list of variables it was computed from (``_parents``).
    Input variables register themselves in ``FADiff._fadscal_inputs``.
    """

    def __init__(self, val, der=None, parents=None, name=None, new_input=False):
        """
        Create a scalar variable.

        Parameters
        ----------
        val : float
            value of the scalar variable
        der : float, dictionary
            derivative of the scalar variable; a single seed value when
            ``new_input`` is true, otherwise the dictionary of partial
            derivatives stored as-is
        parents : list of Scal objects, optional
            the parent/grandparent vars of the variable (a fresh empty list
            is used when omitted, so instances never share one default list)
        name : str
            the name of the variable
        new_input : boolean
            if variable is an input variable
        """
        self._val = val
        if new_input:                            # Creating input var?
            self._der = {}                       # Add gradient dict for new var
            for var in FADiff._fadscal_inputs:   # Update gradient dicts for all vars
                self._der[var] = 0               # Partial der of others as 0 in self
                var._der[self] = 0               # Self's partial der as 0 in others
            self._der[self] = der                # Self's partial der in self
            FADiff._fadscal_inputs.append(self)  # Add self to global vars list
        else:
            self._der = der
        self._name = name                        # TODO: Utilize if have time?
        self._parents = [] if parents is None else parents

    def __add__(self, other):
        """
        Add a Scal object or a constant to self.

        Parameters
        ----------
        other : Scal, constant
            the Scal object or constant being added to self

        Returns
        -------
        new Scal instance
            its value is the sum of the operands' values; its partial
            derivatives are the sums of the operands' partial derivatives,
            where a variable missing from one operand contributes 0
        """
        try:
            # Only the attribute probe is guarded, so an AttributeError raised
            # by anything else is not mistaken for "other is a constant".
            other_val, other_der = other._val, other._der
        except AttributeError:
            # Constant operand: derivatives are unchanged. Copy the dict so the
            # result does not alias self's gradient dictionary.
            return Scal(self._val + other, dict(self._der or {}),
                        self._set_parents(self))
        own_der = self._der or {}
        other_der = other_der or {}
        der = {var: own_der.get(var, 0) + other_der.get(var, 0)
               for var in {**own_der, **other_der}}
        return Scal(self._val + other_val, der, self._set_parents(self, other))

    def __radd__(self, other):
        """Handle ``other + self``; addition is commutative, so reuse __add__."""
        return self.__add__(other)

    def __mul__(self, other):
        """
        Multiply self by a Scal object or a constant.

        Parameters
        ----------
        other : Scal, constant
            the Scal object or constant self is multiplied by

        Returns
        -------
        new Scal instance
            its value is the product of the operands' values; its partial
            derivatives follow the product rule, where a variable missing
            from one operand contributes a partial derivative of 0
        """
        own_der = self._der or {}
        try:
            other_val, other_der = other._val, other._der
        except AttributeError:
            # Constant operand: every partial derivative is scaled by it.
            der = {var: part_der * other for var, part_der in own_der.items()}
            return Scal(self._val * other, der, self._set_parents(self))
        other_der = other_der or {}
        # Product rule over the union of variables tracked by either operand.
        der = {var: self._val * other_der.get(var, 0)
                    + own_der.get(var, 0) * other_val
               for var in {**own_der, **other_der}}
        return Scal(self._val * other_val, der, self._set_parents(self, other))

    def __rmul__(self, other):
        """Handle ``other * self``; multiplication is commutative, so reuse __mul__."""
        return self.__mul__(other)

    @property
    def val(self):
        """Returns the value wrapped in a one-element list"""
        return [self._val]

    @property
    def der(self):
        """
        Returns partial derivatives wrt all root input vars used

        Returns
        -------
        list or None
            the partial derivatives whose variables appear among the parents
            (in the order they are stored); for a variable without such
            entries that is registered in ``FADiff._fadscal_inputs``, a
            one-element list holding its own partial derivative; otherwise
            None
        """
        used = [part_der for var, part_der in self._der.items()
                if var in self._parents]
        if used:                                 # For output vars
            return used
        if self in FADiff._fadscal_inputs:       # For input vars (no parents)
            return [self._der[self]]
        return None

    @staticmethod
    def _set_parents(var1, var2=None):
        """
        Sets parent/grandparent vars (including root input vars used)

        Parameters
        ----------
        var1 : Scal
            first operand
        var2 : Scal, optional
            second operand, omitted when the other operand is a constant

        Returns
        -------
        list
            each operand followed by its own parents, without duplicates, in
            first-seen order
        """
        parents = [var1, *var1._parents]
        if var2 is not None:
            parents.append(var2)
            parents.extend(var2._parents)
        # dict.fromkeys removes duplicates like set() but keeps a stable order.
        return list(dict.fromkeys(parents))
112 return parents