Commit 2cbc97e3 authored by Grégoire Uhlrich's avatar Grégoire Uhlrich
Browse files

Feynman rule file fixed

parent 9d2ac276
......@@ -66,14 +66,20 @@ class Abbrev {
static AbstractParent* find_opt(Expr const& abreviation);
static void compressAbbreviations(std::string const &name = "");
private:
static void addAbreviation(
static void compressAbbreviations_impl(
std::vector<AbstractParent*> &abbreviations
);
static void addAbbreviation(
AbstractParent* ptr,
std::string const &t_name
);
static void removeAbreviation(
static void removeAbbreviation(
AbstractParent* ptr,
std::string const &t_name
);
......@@ -180,12 +186,12 @@ class Abbreviation: public BaseParent {
baseName(t_name),
initialStructure(Abbrev::getFreeStructure(t_encapsulated))
{
Abbrev::addAbreviation(this, t_name);
Abbrev::addAbbreviation(this, t_name);
}
~Abbreviation()
{
Abbrev::removeAbreviation(this, baseName);
Abbrev::removeAbbreviation(this, baseName);
}
bool isAnAbbreviation() const override { return true; }
......
......@@ -54,6 +54,8 @@
#include "hardFactor.h"
#include "partialExpand.h"
#include "patternMatch.h"
#include "hardComparison.h"
#include "dichotomy.h"
#include "cast.h"
#include "precision_float.h"
#include "space.h"
......
#pragma once
namespace csl {
/**
* @brief Template dichotomy algorithm using a comparator.
*
* @details For the insertion of an element e, the comparator given must take
* one argument (of the same type as the range's elements) and return +1 if
* the element to insert is **simpler** than the argument, -1 if it is
* **less simple**, and 0 otherwise.
*
* @tparam Iterator Iterator type.
* @tparam Comparator Comparator type.
* @param first First iterator in the range.
* @param last Last iterator in the range.
* @param f Comparator function.
*
* @return The iterator where the element compared with $$f$$ must be inserted.
*/
template<class Iterator, class Comparator>
Iterator dichotomyFindIf(
Iterator first,
Iterator last,
Comparator &&f
)
{
while (last != first) {
const auto diff = last - first;
Iterator mid = first + diff/2;
// auto const &midExpr = v[mid]->getEncapsulated();
int comp = f(*mid);
if (comp == 1)
last = mid;
else if (comp == -1) {
if (mid == first)
++first;
else
first = mid;
}
else
return mid;
if (first + 1 == mid) {
return (f(*first) == 1) ? first : mid;
}
}
return first;
}
}
#pragma once
namespace csl {
class Expr;
int matchBOnA(csl::Expr const& A, csl::Expr &B);
bool hardComparison(csl::Expr const&, csl::Expr const&);
bool hardOrdering(csl::Expr const&, csl::Expr const&);
}
......@@ -22,6 +22,11 @@ class InitSanitizer {
public:
constexpr InitSanitizer()
{
}
constexpr InitSanitizer(char const t_name[])
:m_name(t_name)
{
......
......@@ -27,7 +27,34 @@ size_t dichoFinder(
std::vector<Node*> const &v
);
void compress(csl::Expr &expr, size_t nIter = 1);
void compress_impl(
std::vector<csl::Expr> &expr,
size_t nIter = 1
);
void compress(
std::vector<csl::Expr> &expr,
size_t nIter = 1
);
inline void compress(
csl::Expr &expr,
size_t nIter = 1
)
{
std::vector<csl::Expr> vec { expr };
compress(vec, nIter);
expr = vec[0];
}
inline void compress_impl(
csl::Expr &expr,
size_t nIter = 1
)
{
std::vector<csl::Expr> vec { expr };
compress_impl(vec, nIter);
expr = vec[0];
}
struct Node {
......
......@@ -23,6 +23,8 @@
#include "algo.h"
#include "utils.h"
#include "scopedProperty.h"
#include "hardComparison.h"
#include "dichotomy.h"
namespace csl {
......@@ -44,26 +46,36 @@ size_t dichoFinder(
std::vector<AbstractParent*> const &v
)
{
size_t first = 0;
size_t last = v.size();
while (last != first) {
size_t mid = (first + last) / 2;
auto const &midExpr = v[mid]->getEncapsulated();
if (expr < midExpr)
last = mid;
else if (midExpr < expr) {
if (mid == first)
++first;
else
first = mid;
}
else
return mid;
if (first + 1 == mid) {
return (expr < v[first]->getEncapsulated()) ? first : mid;
}
}
return first;
auto iter = csl::dichotomyFindIf(v.begin(), v.end(),
[&](AbstractParent const *parent) {
const auto &encaps = parent->getEncapsulated();
if (expr < encaps)
return +1;
else if (encaps < expr)
return -1;
return 0;
});
return iter - v.begin();
// size_t first = 0;
// size_t last = v.size();
// while (last != first) {
// size_t mid = (first + last) / 2;
// auto const &midExpr = v[mid]->getEncapsulated();
// if (expr < midExpr)
// last = mid;
// else if (midExpr < expr) {
// if (mid == first)
// ++first;
// else
// first = mid;
// }
// else
// return mid;
// if (first + 1 == mid) {
// return (expr < v[first]->getEncapsulated()) ? first : mid;
// }
// }
// return first;
}
bool Abbrev::compareParents::operator()(
......@@ -97,7 +109,7 @@ void Abbrev::cleanEmptyAbbreviation()
}
}
void Abbrev::addAbreviation(
void Abbrev::addAbbreviation(
AbstractParent* ptr,
std::string const &name
)
......@@ -107,16 +119,19 @@ void Abbrev::addAbreviation(
// "Abbreviation " + std::string(ptr->getName())
// + " already exists.");
auto &abbreviations = getAbbreviationsForName(name.data());
auto encapsulated = ptr->getEncapsulated();
size_t insertionPos;
if (useDichotomy) {
size_t pos = dichoFinder(ptr->getEncapsulated(), abbreviations);
abbreviations.insert(abbreviations.begin() + pos, ptr);
insertionPos = dichoFinder(encapsulated, abbreviations);
abbreviations.insert(abbreviations.begin() + insertionPos, ptr);
}
else {
insertionPos = abbreviations.size();
abbreviations.push_back(ptr);
}
}
void Abbrev::removeAbreviation(
void Abbrev::removeAbbreviation(
AbstractParent* ptr,
std::string const &name
)
......@@ -180,6 +195,71 @@ AbstractParent* Abbrev::find_opt(Expr const& abreviation)
return find_opt(abreviation->getName());
}
void Abbrev::compressAbbreviations(std::string const &name)
{
if (name.empty()) {
for (auto &el : abbreviationData)
compressAbbreviations_impl(el.second);
}
else {
compressAbbreviations_impl(getAbbreviationsForName(name));
}
}
static std::vector<csl::Expr> commonFactors(
csl::Expr const &prod1,
csl::Expr const &prod2
)
{
auto iter1 = prod1->begin();
auto iter2 = prod2->begin();
const auto end1 = prod1->end();
const auto end2 = prod2->end();
std::vector<csl::Expr> factors;
factors.reserve(std::min(end1 - iter1, end2 - iter2));
while (iter1 != end1 && iter2 != end2) {
if (csl::IsNumerical(*iter1) || *iter1 < *iter2) {
++iter1;
}
else if (csl::IsNumerical(*iter2) || *iter2 < *iter1) {
++iter2;
}
else {
if (*iter1 == *iter2) { // Should not need it if total order ...
factors.push_back(*iter1); // ... but we never know
}
++iter1;
++iter2;
}
}
return factors;
}
void Abbrev::compressAbbreviations_impl(
std::vector<AbstractParent*> &abbreviations
)
{
for (size_t i = 0; i != abbreviations.size(); ++i) {
csl::Expr encaps_i = abbreviations[i]->getEncapsulated();
if (!csl::IsProd(encaps_i) || csl::IsIndexed(encaps_i))
continue;
for (size_t j = i+1; j < abbreviations.size(); ++j) {
csl::Expr encaps_j = abbreviations[i]->getEncapsulated();
if (!csl::IsProd(encaps_j) || csl::IsIndexed(encaps_j))
continue;
const int diff = std::abs(
static_cast<int>(csl::Size(encaps_i))
- static_cast<int>(csl::Size(encaps_j))
);
if (diff <= 2) {
[[maybe_unused]]
std::vector<csl::Expr> factors = commonFactors(
encaps_i, encaps_j);
}
}
}
}
std::string Abbrev::getFinalName(std::string_view initialName)
{
std::string init(initialName);
......@@ -346,7 +426,7 @@ std::optional<Expr> Abbrev::findExisting(
}
else {
for (const auto& ab : abbreviations) {
Expr comparison = DeepCopy(ab->getEncapsulated());
Expr comparison = ab->getEncapsulated();
auto ab_ptr = dynamic_cast<Abbreviation<TensorParent>*>(ab);
if (not ab_ptr) {
continue;
......@@ -354,19 +434,30 @@ std::optional<Expr> Abbrev::findExisting(
if (structure.size() != ab_ptr->initialStructure.size()) {
continue;
}
auto intermediate = structure;
for (auto &i : intermediate)
i = i.rename();
for (size_t i = 0; i != ab_ptr->initialStructure.size(); ++i)
Replace(comparison,
ab_ptr->initialStructure[i],
intermediate[i],
false);
for (size_t i = 0; i != ab_ptr->initialStructure.size(); ++i)
Replace(comparison,
intermediate[i],
structure[i],
false);
if (comparison->getType() != encapsulated->getType()
|| comparison->size() != encapsulated->size()) {
continue;
}
bool diff = false;
for (size_t i = 0; i != encapsulated->size(); ++i)
if (encapsulated[i]->getType() != comparison[i]->getType()) {
diff = true;
break;
}
if (diff) {
continue;
}
csl::Replace(comparison, ab_ptr->initialStructure, structure);
//for (size_t i = 0; i != ab_ptr->initialStructure.size(); ++i)
// Replace(comparison,
// ab_ptr->initialStructure[i],
// intermediate[i],
// false);
//for (size_t i = 0; i != ab_ptr->initialStructure.size(); ++i)
// Replace(comparison,
// intermediate[i],
// structure[i],
// false);
std::map<csl::Index, csl::Index> mapping;
if (encapsulated->compareWithDummy(comparison.get(), mapping))
return (*ab)(structure.getIndex());
......@@ -411,6 +502,13 @@ Expr Abbrev::makeAbbreviation(std::string name,
Expr encaps = DeepRefreshed(encapsulated);
if (encaps->size() == 0) // nothing to abbreviate
return encaps;
if (csl::IsProd(encaps) && csl::IsNumerical(encaps[0])
&& csl::Size(encaps) > 2) {
auto prod = csl::prod_s(
std::vector<csl::Expr>(encaps->begin()+1, encaps->end()), true);
auto prodAbbrev = makeAbbreviation(name, prod, split);
return makeAbbreviation(name, prodAbbrev*encaps[0], split);
}
if (name == "Ab"
&& split
&& (csl::IsSum(encaps) || csl::IsProd(encaps))) {
......
#include "interface.h"
#include "abreviation.h"
namespace csl {
static void sortTensors(std::vector<csl::Expr> &tensors)
{
auto free = [&](csl::IndexStructure const &index) {
return csl::Abbrev::getFreeStructure(index);
};
std::sort(tensors.begin(), tensors.end());
std::reverse(tensors.begin(), tensors.end());
std::vector<csl::Expr> sorted;
sorted.reserve(tensors.size());
csl::IndexStructure contractedIndices;
auto step = [&](size_t pos) {
contractedIndices = free(
contractedIndices + tensors[pos]->getIndexStructureView());
sorted.push_back(tensors[pos]);
tensors.erase(tensors.begin() + pos);
};
while (!tensors.empty()) {
if (contractedIndices.empty()) {
step(0);
continue;
}
bool foundCommon = false;
for (size_t i = 0; i != tensors.size(); ++i) {
if (csl::IsTensorField(tensors[i]))
continue;
auto const &index = tensors[i]->getIndexStructureView();
if (contractedIndices.hasCommonIndex(index)) {
step(i);
foundCommon = true;
break;
}
}
if (!foundCommon)
step(0);
}
tensors = std::move(sorted);
}
int matchBOnA(csl::Expr const& A, csl::Expr &B)
{
std::vector<csl::Expr> tensorsInA;
std::vector<csl::Expr> tensorsInB;
csl::VisitEachLeaf(A, [&](csl::Expr const& el)
{
if (el->isIndexed())
tensorsInA.push_back(el);
});
csl::VisitEachLeaf(B, [&](csl::Expr const& el)
{
if (el->isIndexed())
tensorsInB.push_back(el);
});
if (tensorsInA.size() != tensorsInB.size()) {
return tensorsInA.size() < tensorsInB.size();
}
sortTensors(tensorsInA);
sortTensors(tensorsInB);
std::vector<std::pair<csl::Index, csl::Index>> mapping;
for (size_t i = 0; i != tensorsInA.size(); ++i) {
if (tensorsInA[i]->getParent_info()
!= tensorsInB[i]->getParent_info()) {
return tensorsInA[i]->getName() < tensorsInB[i]->getName();
}
else {
csl::IndexStructure Astruct = tensorsInA[i]->getIndexStructure();
auto last = std::remove_if(Astruct.begin(), Astruct.end(),
[&](csl::Index const &i) { return i.getFree(); });
Astruct.erase(last, Astruct.end());
csl::IndexStructure Bstruct = tensorsInB[i]->getIndexStructure();
last = std::remove_if(Bstruct.begin(), Bstruct.end(),
[&](csl::Index const &i) { return i.getFree(); });
Bstruct.erase(last, Bstruct.end());
for (size_t j = 0; j != Astruct.size(); ++j) {
auto pos = std::find_if(
mapping.begin(),
mapping.end(),
[&](std::pair<csl::Index, csl::Index> const& p)
{
return p.second == Astruct[j];
});
if (pos == mapping.end())
mapping.push_back({ Bstruct[j], Astruct[j] });
}
}
}
std::vector<csl::Index> intermediateIndices;
intermediateIndices.reserve(mapping.size());
for (const auto &mappy : mapping)
intermediateIndices.push_back(mappy.first.rename());
size_t index = 0;
for (auto& mappy : mapping) {
B = csl::Replaced(
B,
mappy.first,
intermediateIndices[index],
false);
if (mappy.first.getSpace()->getSignedIndex())
B = csl::Replaced(
B,
mappy.first.getFlipped(),
intermediateIndices[index].getFlipped(),
false);
++index;
}
index = 0;
for (auto& mappy : mapping) {
B = csl::Replaced(
B,
intermediateIndices[index],
mappy.second,
false);
if (mappy.first.getSpace()->getSignedIndex())
B = csl::Replaced(
B,
intermediateIndices[index].getFlipped(),
mappy.second.getFlipped(),
false);
++index;
}
csl::DeepRefresh(B);
return -1;
}
static bool hardComparison_impl(
csl::Expr const &A,
csl::Expr &B)
{
const int match = matchBOnA(A, B);
if (match != -1)
return false;
return A->compareWithDummy(B.get());
}
bool hardComparison(
csl::Expr const& A,
csl::Expr const& B
)
{
auto B_renameIndices = csl::DeepCopy(B);
csl::RenameIndices(B_renameIndices);
return hardComparison_impl(A, B_renameIndices);
}
static bool hardOrdering_impl(
csl::Expr const &A,
csl::Expr &B)
{
const int match = matchBOnA(A, B);
if (match != -1)
return match;
return A < B;
}
bool hardOrdering(
csl::Expr const& A,
csl::Expr const& B
)
{
auto B_renameIndices = csl::DeepCopy(B);
csl::RenameIndices(B_renameIndices);
return hardOrdering_impl(A, B_renameIndices);
}
}
......@@ -294,6 +294,9 @@ bool HardFactorImplementation(
else
innerTerms.push_back((*first) / factor);
}
else if (csl::IsProd(factor) || csl::IsPow(factor)) {
innerTerms.push_back(*first / factor);
}
else {
innerTerms.push_back((**first).suppressTerm(factor.get()));
}
......@@ -327,7 +330,7 @@ void HardFactor(Expr &init)
void DeepHardFactor(Expr &init)
{
csl::ForEachNodeReversed(init, [](csl::Expr &expr) {
csl::ForEachNode(init, [](csl::Expr &expr) {
if (csl::IsSum(expr)) {
HardFactorImplementation(expr, true);
}
......
......@@ -14,6 +14,7 @@
// along with MARTY. If not, see <https://www.gnu.org/licenses/>.
#include "librarydependency.h"
#include "abreviation.h"
#include "interface.h"
#include "indicial.h"
#include "tensorField.h"
......@@ -91,8 +92,9 @@ void DeepExpandIf_lock(
return f(node);
});
}, inplace);
if (refactor)
if (refactor) {
csl::DeepFactor(expression);
}
csl::Lock::unlock(expression, lockID);
}
......@@ -638,7 +640,13 @@ LibDependency GetLibraryDependencies(Expr const& expr)
LibDependency dependencies;
csl::VisitEachNode(expr, [&](Expr const& node)
{
dependencies += node->getLibDependency();