(defn- covariant-derivative-argument-types
  "Builds, from the connection `Cartan`, a curried covariant-derivative
  operator: the result takes a vector field `V` and returns a function of `T`.
  `T` must expose its argument types through `ci/argument-types`, and every
  entry must derive from `::vf/vector-field` or `::ff/oneform-field`.

  NOTE: The derivative that comes back is tagged (via
  `ci/with-argument-types`) with the same argument types as the original input
  function `T`."
  [Cartan]
  (let [basis         (Cartan->basis Cartan)
        e-basis       (b/basis->vector-basis basis)
        w-basis       (b/basis->oneform-basis basis)
        forms-of      (Cartan->forms Cartan)
        vector-slot?  (fn [t] (isa? t ::vf/vector-field))
        oneform-slot? (fn [t] (isa? t ::ff/oneform-field))]
    (fn [V]
      (let [CV (forms-of V)]
        (fn [T]
          (let [arg-types (ci/argument-types T)]
            ;; Only vector-field and oneform-field slots are handled below.
            (assert (every? (some-fn vector-slot? oneform-slot?) arg-types))
            (letfn [;; Walks the slots left to right. For each slot it
                    ;; contracts over `basis`, putting a basis element into
                    ;; T's slot and recording the matching component of the
                    ;; supplied argument as a scalar weight. Once all slots are
                    ;; consumed, V is applied to T of the basis elements and
                    ;; scaled by the product of the weights.
                    (expand [types args slots weights]
                      (if (seq types)
                        (let [slot-type  (first types)
                              arg        (first args)
                              more-types (rest types)
                              more-args  (rest args)]
                          (b/contract
                           (fn [e w]
                             (cond
                               (vector-slot? slot-type)
                               (do (assert (vf/vector-field? arg))
                                   (expand more-types
                                           more-args
                                           (conj slots e)
                                           (conj weights (w arg))))

                               (oneform-slot? slot-type)
                               (do (assert (ff/oneform-field? arg))
                                   (expand more-types
                                           more-args
                                           (conj slots w)
                                           (conj weights (arg e))))))
                           basis))
                        (g/* (V (apply T slots))
                             (apply g/* weights))))

                    (derivative [& args]
                      ;; Arity must agree with T's declared argument types.
                      (assert (= (count args) (count arg-types)))
                      (let [argv (vec args)

                            ;; T applied to `argv` with slot `i` replaced by `x`.
                            with-slot
                            (fn [i x]
                              (apply T (assoc argv i x)))

                            ;; Connection correction for a oneform slot
                            ;; (enters with a positive sign).
                            oneform-term
                            (fn [i]
                              (g/* (g/* (s/mapr (fn [e] ((nth argv i) e))
                                                e-basis)
                                        CV)
                                   (s/mapr (fn [w] (with-slot i w))
                                           w-basis)))

                            ;; Connection correction for a vector-field slot
                            ;; (enters negated).
                            vector-term
                            (fn [i]
                              (g/negate
                               (g/* (s/mapr (fn [e] (with-slot i e))
                                            e-basis)
                                    (g/* CV
                                         (s/mapr (fn [w] (w (nth argv i)))
                                                 w-basis)))))

                            VT (expand arg-types argv [] [])

                            corrections
                            (ua/generic-sum
                             (map-indexed
                              (fn [i type]
                                (cond
                                  (oneform-slot? type) (oneform-term i)
                                  (vector-slot? type)  (vector-term i)))
                              arg-types))]
                        (g/+ VT corrections)))]
              (ci/with-argument-types derivative arg-types))))))))